
Google CTF 2021 Pythia

Top News03/07/2022

By Florian Picca

We have access to a decryption oracle, but the algorithm used is AES-GCM. The oracle only tells us whether a message decrypted successfully or not. Queries are very limited, so we use a partitioning oracle attack to reduce the number of queries and recover the keys.

Details
 

  • Category: crypto
  • Points: 173
  • Solves: 65

Description

“Yet another oracle, but the queries are costly and limited so be frugal with them.”

nc pythia.2021.ctfcompetition.com 1337

Source code:

#!/usr/bin/python -u
import random
import string
import time

from base64 import b64encode, b64decode
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt

max_queries = 150
query_delay = 10

passwords = [bytes(''.join(random.choice(string.ascii_lowercase) for _ in range(3)), 'UTF-8') for _ in range(3)]
flag = open("flag.txt", "rb").read()

def menu():
    print("What you wanna do?")
    print("1- Set key")
    print("2- Read flag")
    print("3- Decrypt text")
    print("4- Exit")
    try:
        return int(input(">>> "))
    except:
        return -1

print("Welcome!\n")

key_used = 0

for query in range(max_queries):
    option = menu()

    if option == 1:
        print("Which key you want to use [0-2]?")
        try:
            i = int(input(">>> "))
        except:
            i = -1
        if i >= 0 and i <= 2:
          key_used = i
        else:
          print("Please select a valid key.")
    elif option == 2:
        print("Password?")
        passwd = bytes(input(">>> "), 'UTF-8')

        print("Checking...")
        # Prevent bruteforce attacks...
        time.sleep(query_delay)
        if passwd == (passwords[0] + passwords[1] + passwords[2]):
            print("ACCESS GRANTED: " + flag.decode('UTF-8'))
        else:
            print("ACCESS DENIED!")
    elif option == 3:
        print("Send your ciphertext ")

        ct = input(">>> ")
        print("Decrypting...")
        # Prevent bruteforce attacks...
        time.sleep(query_delay)
        try:
            nonce, ciphertext = ct.split(",")
            nonce = b64decode(nonce)
            ciphertext = b64decode(ciphertext)
        except:
            print("ERROR: Ciphertext has invalid format. Must be of the form \"nonce,ciphertext\", where nonce and ciphertext are base64 strings.")
            continue

        kdf = Scrypt(salt=b'', length=16, n=2**4, r=8, p=1, backend=default_backend())
        key = kdf.derive(passwords[key_used])
        try:
            cipher = AESGCM(key)
            plaintext = cipher.decrypt(nonce, ciphertext, associated_data=None)
        except:
            print("ERROR: Decryption failed. Key was not correct.")
            continue

        print("Decryption successful")
    elif option == 4:
        print("Bye!")
        break
    else:
        print("Invalid option!")
    print("You have " + str(max_queries - query) + " trials left...\n")

Understanding the problem

The server generates 3 passwords composed of 3 lowercase letters and derives 3 keys from them.

We can submit AES-GCM encrypted messages and the server will tell us whether decryption succeeded or not. We can also specify which of the 3 keys the server should use when decrypting.

To obtain the flag we have to recover all 3 passwords. To prevent brute force attacks, we only have a total of 150 queries and each of them is answered after a 10-second delay.
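As a minimal sketch of the wire format accepted by option 3, the snippet below builds a "nonce,ciphertext" line and parses it back the same way the server does (split on the comma, then base64-decode both halves). The helper names format_query and parse_query are hypothetical and only serve the illustration.

```python
from base64 import b64encode, b64decode

def format_query(nonce: bytes, ciphertext: bytes) -> str:
    """Build the 'nonce,ciphertext' line expected by option 3."""
    return f"{b64encode(nonce).decode()},{b64encode(ciphertext).decode()}"

def parse_query(line: str):
    """Mirror of the server-side parsing: split on the comma, then base64-decode."""
    nonce, ciphertext = line.split(",")
    return b64decode(nonce), b64decode(ciphertext)

# A 12-byte all-zero nonce and a dummy 32-byte payload (ciphertext block + tag)
line = format_query(b"\x00" * 12, b"\x01" * 32)
print(line)
```

The standard base64 alphabet never contains a comma, so the server's split(",") always yields exactly two parts for a well-formed query.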

Solving the problem

The server is clearly acting as a decryption oracle. Unlike a CBC padding oracle, it does not check that the padding is valid (there is none); it checks that the GCM tag is valid. An oracle of this kind is called a partitioning oracle. Such oracles are not limited to AES-GCM and can affect other AEAD encryption schemes. The attack is presented in the paper "Partitioning Oracle Attacks" by Len, Grubbs and Ristenpart. In particular, section 3.1 describes a way to build a cipher text with a given tag and nonce that is valid under a whole set of different keys. The authors call it a key multi-collision attack and even provide an open-source implementation.
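To make the multi-collision idea more concrete, here is a sketch of the underlying equation for the case used in this challenge (no associated data, 96-bit nonce). The notation is an assumption of this sketch: C_1, ..., C_m are the 128-bit cipher text blocks, L is the length block, H is the GHASH key and T the chosen tag. All arithmetic is in GF(2^128), where addition is XOR.

```latex
\begin{aligned}
T &= C_1 H^{m+1} + C_2 H^{m} + \dots + C_m H^{2} + L\,H + E_K(N \,\|\, 0^{31}1),
  \qquad H = E_K(0^{128}) \\[4pt]
\sum_{i=1}^{m} C_i\, H^{\,m-i} &= \bigl(L\,H + E_K(N \,\|\, 0^{31}1) + T\bigr)\, H^{-2}
\end{aligned}
```

For a fixed nonce N and tag T, each candidate key K_j yields one pair (H_j, B_j), where B_j is the right-hand side of the second line. A polynomial of degree at most k-1 passing through the k points (H_j, B_j) has k coefficients, and reading them from the highest degree down gives blocks C_1, ..., C_k that verify under every key of the set.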

In our case, there are only 26^3 = 17576 possible passwords, thus our key space is rather small. We could use the multi-collision attack to forge a cipher text that is valid for half of the keys and query the oracle. If the oracle says the decryption is valid, we can deduce that the real key must be one of those we used to forge our cipher text; otherwise it is in the other half. We can then halve the search space once again and recover the real key using a simple binary search algorithm. This would have to be done 3 times, to recover the 3 keys and from them recover the passwords.
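The binary search can be sketched against a mock oracle. Here make_oracle is a hypothetical stand-in for "forge a cipher text valid for keys start..end-1 and ask the server"; it simply answers whether the secret index lies in the range, which is the only information the real oracle leaks.

```python
def make_oracle(secret_index):
    """Mock oracle: True iff the secret index lies in [start, end)."""
    calls = {"n": 0}

    def oracle(start, end):
        calls["n"] += 1
        return start <= secret_index < end

    return oracle, calls

def binary_search(oracle, start, end):
    """Return the secret index, assuming it lies in [start, end)."""
    while end - start > 1:
        mid = (start + end) // 2
        if oracle(start, mid):
            end = mid      # secret is in the lower half
        else:
            start = mid    # secret must be in the upper half
    return start

N = 26 ** 3
oracle, calls = make_oracle(12345)
found = binary_search(oracle, 0, N)
print(found, calls["n"])
```

A negative answer is as informative as a positive one, because the key is known to be somewhere in the current range.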

We could expect our binary search to completely recover a single key in about log2(26^3) ≈ 14.1, so at most 15 steps, making the overall attack take less than 50 queries. However, the first step alone would require a cipher text that is valid under 26^3/2 = 8788 keys, which would take way too long to compute. The time complexity of finding such a cipher text is around O(k^2), with k being the number of keys it must be valid for.
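The numbers above can be checked with a few lines of arithmetic. The cost ratio assumes the O(k^2) estimate holds exactly, which is only a rough model, and the cipher text sizes assume one 16-byte block per key plus the 16-byte tag.

```python
import math

keyspace = 26 ** 3                        # number of candidate passwords
steps = math.ceil(math.log2(keyspace))    # worst-case binary search depth
half = keyspace // 2                      # keys covered by the first forged cipher text
chunk = 500                               # a smaller key set, for comparison

# Relative forging cost under the O(k^2) model
cost_ratio = (half / chunk) ** 2

# Forged cipher text size: one block per key, plus the tag
ct_bytes_half = 16 * half + 16
ct_bytes_chunk = 16 * chunk + 16

print(keyspace, steps, half, round(cost_ratio), ct_bytes_half, ct_bytes_chunk)
```

Under this model, forging for half of the key space costs roughly 300 times more than forging for 500 keys, and the forged message also grows from about 8 KB to about 140 KB.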

We can instead split the search into chunks of 500 keys, which will require more queries but make the cipher texts much faster to forge.

We will have to search at most 36 chunks. Once we find the chunk that contains the key, we can then use binary search to recover it, adding at most 9 additional queries. In this way, we can fully recover a single key in at most 45 queries. The whole attack therefore needs at most 3 × 45 = 135 oracle queries, plus 2 to switch keys and 1 to read the flag, which stays below the 150-query limit.
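The query budget can be double-checked by simulating the chunk scan followed by the binary search for every possible key position. This is a sketch with a simulated oracle (a plain range test); queries_needed is a hypothetical helper, not part of the exploit.

```python
def queries_needed(secret, n=26 ** 3, chunk=500):
    """Count oracle queries used by chunk scan + binary search for one key."""
    queries = 0
    for start in range(0, n, chunk):
        queries += 1                          # one query per scanned chunk
        if start <= secret < start + chunk:   # oracle says "valid" for this chunk
            lo, hi = start, start + chunk
            while hi - lo > 1:                # binary search inside the chunk
                mid = (lo + hi) // 2
                queries += 1                  # query on the lower half [lo, mid)
                if secret < mid:
                    hi = mid
                else:
                    lo = mid
            assert lo == secret
            return queries
    raise ValueError("secret outside the key space")

worst = max(queries_needed(s) for s in range(26 ** 3))
total = 3 * worst + 2 + 1   # three keys, two key switches, one flag request
print(worst, total)
```

Smaller chunks make forging cheaper but increase the number of chunks to scan, so the chunk size is a trade-off between computation time and the query limit.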

Implementing the solution - Building the key set

import pickle
import itertools
import string
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
from cryptography.hazmat.backends import default_backend

# build all keys and pickle them for next time
try:
    keys = pickle.load(open("keys.pickle", "rb"))
except FileNotFoundError:
    keys = {}
    for t in itertools.product(string.ascii_lowercase, repeat=3):
        pwd = "".join(t).encode()
        print(pwd)
        kdf = Scrypt(salt=b'', length=16, n=2 ** 4, r=8, p=1, backend=default_backend())
        keys[kdf.derive(pwd)] = pwd
    # dump once, after all keys have been derived
    pickle.dump(keys, open("keys.pickle", "wb"))

We can make a dictionary storing all the possible keys and their associated password. This will make the password recovery process easier, as we will only need to look up an entry in the dictionary. Pickle is used to store the result of this computation, so we do not have to recompute everything if we restart the script (which happens a lot during testing).
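For readers without the cryptography package, the same reverse table can be sketched with the standard library. This assumes hashlib.scrypt is available (it requires a Python build linked against an OpenSSL version with scrypt support) and that it matches the Scrypt class used above for identical parameters, which is expected since both implement the same algorithm but was not checked against the challenge server. The derive_key helper is hypothetical.

```python
import hashlib
import itertools
import string

def derive_key(password: bytes) -> bytes:
    """scrypt with the challenge's parameters: empty salt, n=2**4, r=8, p=1, 16-byte key."""
    return hashlib.scrypt(password, salt=b"", n=2 ** 4, r=8, p=1, dklen=16)

# Reverse table: derived key -> password.
# Only passwords starting with 'a' are covered here to keep the example short;
# the writeup builds the table for all 26**3 passwords.
table = {}
for rest in itertools.product(string.ascii_lowercase, repeat=2):
    pwd = ("a" + "".join(rest)).encode()
    table[derive_key(pwd)] = pwd

# Once a key has been identified, its password is a single dictionary lookup
print(len(table), table[derive_key(b"abc")])
```

Because the dictionary preserves insertion order, list(table.keys()) also gives a stable ordering of the keys, which is what the chunked search relies on when it slices the key list.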

Implementing the solution - Implementing the multi-collision attack

from sage.all import *
from Crypto.Util.number import long_to_bytes
from Crypto.Cipher import AES
from bitstring import BitArray
import functools

P = PolynomialRing(GF(2), "x")
x = P.gen()
p = x ** 128 + x ** 7 + x ** 2 + x + 1
GH = GF(2 ** 128, "a", modulus=p)

def bytes_to_GH(data):
    """Simply convert bytes to field elements"""
    return GH([int(v) for v in BitArray(data).bin])

def GH_to_bytes(element):
    """Simply convert field elements to bytes"""
    return BitArray(element.polynomial().list()).tobytes().ljust(16, b'\x00')

def multi_collide_gcm(keyset, nonce, tag):
    R = PolynomialRing(GH, "r")
    L = bytes_to_GH(long_to_bytes(128 * len(keyset), 16))
    N = nonce + b'\x00\x00\x00\x01'
    T = bytes_to_GH(tag)
    interpolation_pts = []
    for key in keyset:
        H = bytes_to_GH(AES.new(key, AES.MODE_ECB).encrypt(b'\x00' * 16))
        B = ((L * H) + bytes_to_GH(AES.new(key, AES.MODE_ECB).encrypt(N)) + T) * H**-2
        interpolation_pts.append((H, B))
    sol = R.lagrange_polynomial(interpolation_pts)
    C_blocks = [GH_to_bytes(c) for c in sol.list()[::-1]]
    return b''.join(C_blocks) + tag

# cache the results for speedup, could have precomputed them but it's not that slow
@functools.lru_cache(maxsize=None)
def forge(start, end):
    keyset = list(keys.keys())[start:end]
    r = multi_collide_gcm(keyset, b'\x00'*12, b'\x01'*16)
    return r

This implementation is a rewrite of the open-source implementation provided with the paper. We use memoization so that a cipher text already forged for a given range of keys is not computed again.
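The core of multi_collide_gcm is a Lagrange interpolation: one polynomial is made to pass through one point per key. The sketch below illustrates that single step in pure Python over a toy prime field instead of GF(2^128), so it is not a GCM forgery; the modulus P and the sample points are arbitrary choices for the illustration.

```python
P = 2 ** 61 - 1  # toy prime modulus, NOT the GCM field

def poly_mul_linear(poly, root):
    """Multiply poly (coefficients, lowest degree first) by (x - root) mod P."""
    out = [0] * (len(poly) + 1)
    for i, c in enumerate(poly):
        out[i] = (out[i] - c * root) % P
        out[i + 1] = (out[i + 1] + c) % P
    return out

def lagrange(points):
    """Coefficients (lowest degree first) of the polynomial through all points."""
    result = [0] * len(points)
    for j, (xj, yj) in enumerate(points):
        basis, denom = [1], 1
        for m, (xm, _) in enumerate(points):
            if m != j:
                basis = poly_mul_linear(basis, xm)   # basis vanishes at every other x
                denom = denom * (xj - xm) % P
        scale = yj * pow(denom, -1, P) % P           # normalise so basis(xj) == yj
        for i, c in enumerate(basis):
            result[i] = (result[i] + c * scale) % P
    return result

def evaluate(poly, x):
    """Horner evaluation mod P."""
    acc = 0
    for c in reversed(poly):
        acc = (acc * x + c) % P
    return acc

# One (x, y) constraint per "key"; a single polynomial satisfies all of them
points = [(3, 10), (7, 99), (11, 5), (20, 42)]
coeffs = lagrange(points)
print(all(evaluate(coeffs, x) == y for x, y in points))
```

In the real attack the x values are the GHASH keys H, the y values are the targets B computed for each key, and the coefficients become the cipher text blocks, which is why the forged message has one block per key.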

We can now split our key space in chunks and use binary search to recover the key afterwards.

Implementing the solution - Full script 

import pickle
import itertools
import string
import base64
from sage.all import *
from Crypto.Util.number import long_to_bytes
from Crypto.Cipher import AES
from pwn import *
import functools
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
from cryptography.hazmat.backends import default_backend
from bitstring import BitArray

def bytes_to_GH(data):
    """Simply convert bytes to field elements"""
    return GH([int(v) for v in BitArray(data).bin])

def GH_to_bytes(element):
    """Simply convert field elements to bytes"""
    return BitArray(element.polynomial().list()).tobytes().ljust(16, b'\x00')

def multi_collide_gcm(keyset, nonce, tag):
    R = PolynomialRing(GH, "r")
    L = bytes_to_GH(long_to_bytes(128 * len(keyset), 16))
    N = nonce + b'\x00\x00\x00\x01'
    T = bytes_to_GH(tag)
    interpolation_pts = []
    for key in keyset:
        H = bytes_to_GH(AES.new(key, AES.MODE_ECB).encrypt(b'\x00' * 16))
        B = ((L * H) + bytes_to_GH(AES.new(key, AES.MODE_ECB).encrypt(N)) + T) * H**-2
        interpolation_pts.append((H, B))
    sol = R.lagrange_polynomial(interpolation_pts)
    C_blocks = [GH_to_bytes(c) for c in sol.list()[::-1]]
    return b''.join(C_blocks) + tag

# cache the results for speedup, could have precomputed them but it's not that slow
@functools.lru_cache(maxsize=None)
def forge(start, end):
    keyset = list(keys.keys())[start:end]
    r = multi_collide_gcm(keyset, b'\x00'*12, b'\x01'*16)
    return r

def setKey(i):
    conn.sendline(b"1")
    conn.recvuntil(b">>> ")
    conn.sendline(f"{i}".encode())
    conn.recvuntil(b">>> ")

def decrypt(c):
    conn.sendline(b"3")
    conn.recvuntil(b">>> ")
    t = f"{nonce.decode()},{base64.b64encode(c).decode()}"
    conn.sendline(t.encode())
    conn.recvline()
    r = conn.recvline()
    conn.recvuntil(b">>> ")
    if b"Decryption failed." in r:
        return False
    return True

def bsearch(start, end):
    global tries
    mid = (end + start)//2
    if end - start == 1:
        return start
    tries -= 1
    print(f"tries left : {tries}")
    if decrypt(forge(start, mid)):
        return bsearch(start, mid)
    else:
        return bsearch(mid, end)

    
if __name__ == "__main__":
    # build all keys and pickle them for next time
    try:
        keys = pickle.load(open("keys.pickle", "rb"))
    except FileNotFoundError:
        keys = {}
        for t in itertools.product(string.ascii_lowercase, repeat=3):
            pwd = "".join(t).encode()
            print(pwd)
            kdf = Scrypt(salt=b'', length=16, n=2 ** 4, r=8, p=1, backend=default_backend())
            keys[kdf.derive(pwd)] = pwd
        pickle.dump(keys, open("keys.pickle", "wb"))

    # global variables
    P = PolynomialRing(GF(2), "x")
    x = P.gen()
    p = x ** 128 + x ** 7 + x ** 2 + x + 1
    GH = GF(2 ** 128, "a", modulus=p)

    tries = 150
    N = 26**3
    B = 500
    nonce = base64.b64encode(b'\x00' * 12)

    # recover the passwords and get the flag
    conn = remote("pythia.2021.ctfcompetition.com", 1337)
    # local testing
    # conn = process("./pythia.py")
    conn.recvuntil(b">>> ")

    password = b''
    # 3 passwords in total
    for j in range(3):
        # search in chunks
        for i in range(0, N, B):
            # if key is in this chunk
            if decrypt(forge(i, i + B)):
                print("Entering binary search...")
                index = bsearch(i, i+B)
                pwd = keys[list(keys.keys())[index]]
                password += pwd
                print(f"Found password : {pwd}")
                break
            tries -= 1
            print(f"tries left : {tries}")
        if j < 2:
            setKey(j+1)
            tries -= 1
            print(f"tries left : {tries}")
    print(f"full password = {password.decode()}")
    conn.sendline(b"2")
    conn.recvuntil(b">>> ")
    conn.sendline(password)
    conn.recvuntil(b"ACCESS GRANTED: ")
    print(f"Flag : {conn.recvline().decode()}")
    conn.close()

Running it gives:

[x] Opening connection to pythia.2021.ctfcompetition.com on port 1337
[x] Opening connection to pythia.2021.ctfcompetition.com on port 1337: Trying 34.77.25.116
[+] Opening connection to pythia.2021.ctfcompetition.com on port 1337: Done
tries left : 149
tries left : 148
tries left : 147
tries left : 146
tries left : 145
tries left : 144
tries left : 143
tries left : 142
tries left : 141
tries left : 140
tries left : 139
tries left : 138
tries left : 137
tries left : 136
tries left : 135
tries left : 134
tries left : 133
tries left : 132
tries left : 131
tries left : 130
tries left : 129
tries left : 128
tries left : 127
tries left : 126
tries left : 125
tries left : 124
tries left : 123
tries left : 122
tries left : 121
tries left : 120
tries left : 119
tries left : 118
Entering binary search...
tries left : 117
tries left : 116
tries left : 115
tries left : 114
tries left : 113
tries left : 112
tries left : 111
tries left : 110
tries left : 109
Found password : b'xvw'
tries left : 108
tries left : 107
tries left : 106
tries left : 105
tries left : 104
tries left : 103
tries left : 102
tries left : 101
tries left : 100
tries left : 99
tries left : 98
tries left : 97
tries left : 96
tries left : 95
tries left : 94
tries left : 93
tries left : 92
tries left : 91
tries left : 90
tries left : 89
tries left : 88
tries left : 87
tries left : 86
tries left : 85
tries left : 84
tries left : 83
tries left : 82
tries left : 81
tries left : 80
tries left : 79
tries left : 78
Entering binary search...
tries left : 77
tries left : 76
tries left : 75
tries left : 74
tries left : 73
tries left : 72
tries left : 71
tries left : 70
tries left : 69
Found password : b'woc'
tries left : 68
tries left : 67
tries left : 66
tries left : 65
tries left : 64
tries left : 63
tries left : 62
tries left : 61
tries left : 60
tries left : 59
Entering binary search...
tries left : 58
tries left : 57
tries left : 56
tries left : 55
tries left : 54
tries left : 53
tries left : 52
tries left : 51
tries left : 50
Found password : b'hcj'
full password = xvwwochcj
Flag : CTF{gCm_1s_n0t_v3ry_r0bust_4nd_1_sh0uld_us3_s0m3th1ng_els3_h3r3}

[*] Closed connection to pythia.2021.ctfcompetition.com port 1337

Flag : CTF{gCm_1s_n0t_v3ry_r0bust_4nd_1_sh0uld_us3_s0m3th1ng_els3_h3r3}
