Greetings! After the sphinx challenge from Google CTF 2025, I took an interest in differential cryptanalysis, so I was happy to try and solve this one. Here I'll describe my solution, which seems to be one of the most efficient ways to solve this challenge, so I hope you'll find it interesting. Let's get straight into it.

The challenge

The challenge includes a couple of files, because the distribution contains both the optimized C code and a SageMath file. Having a fast implementation allows the input limit on the remote to be a whopping 60 MB, of which (spoiler alert) I used a surprisingly small 64 KB.

I’ll be using Python with SageMath. My code, edited and expanded from chall.sage, can be found here.

I will index everything from 0. Let’s also define by .

The cipher uses four round keys. The first one is the provided key, others are derived from it with SHA-256. Let’s denote those byte vectors .

The other components are:

  • a binary invertible 16x16 matrix . Although , this did not prove to be helpful in my attack.
  • an . Seemingly, it does not have useful features.
  • a collection of 16 linear functions .
  • I’ll denote for numbers and when for vectors.

All the arithmetic with bytes is done modulo 256.

The encryption consists of three rounds. If denotes the state after round , then and . Finally, .

The idea

First of all, notice how and so we may as well look for differentials from that point in the cipher. The multiplication by is the only source of diffusion, so is then similar to (the zeros are the same).

What if ? Then is similar to .

This fact turns out to be very helpful. Since is binary, both and are binary as well. So consists only of components equal to or . So there are two possible patterns of zeros here, depending on if or not. If we can distinguish those, we get an important piece of information about and .
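The two zero patterns are easy to check numerically. The pure-Python sketch below uses a hypothetical random binary 16×16 matrix (not the challenge's M) and puts a difference x at position i and y at position j. Every output byte is then 0, x, y or x + y modulo 256. Choosing y ≡ −x therefore adds zeros in exactly the rows that have a 1 in both columns.

```python
import random

def mat_vec_mod256(m, v):
    """Multiply a 16x16 integer matrix by a length-16 vector modulo 256."""
    return [sum(a * b for a, b in zip(row, v)) % 256 for row in m]

def zero_pattern(m, i, j, x, y):
    """Set of positions where m * (x*e_i + y*e_j) is zero modulo 256."""
    diff = [0] * 16
    diff[i] = x
    diff[j] = y
    out = mat_vec_mod256(m, diff)
    return {k for k in range(16) if out[k] == 0}

rng = random.Random(1)
# Hypothetical binary matrix, NOT the challenge's M.
M = [[rng.randint(0, 1) for _ in range(16)] for _ in range(16)]

i, j, x = 3, 7, 0x5A
generic = zero_pattern(M, i, j, x, 0x11)         # y unrelated to x
opposite = zero_pattern(M, i, j, x, (-x) % 256)  # y == -x (mod 256)

# Rows with a 1 in both column i and column j are exactly the extra zeros.
both_set = {k for k in range(16) if M[k][i] == 1 and M[k][j] == 1}
print(sorted(opposite - generic) == sorted(both_set))  # prints True
```

The constants 0x5A and 0x11 are arbitrary. Any x, y with x, y and x + y all nonzero give the same comparison.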

Luckily, is sparse enough that, in the last multiplication, it maps many of those patterns to distinct patterns in the output. For example, here is one of the applicable characteristics:

Values marked with an asterisk are zero when and are (usually) nonzero when . There are three such places in the ciphertext, so it will be a good guess that the probability of all three of those being zero when is , which suggests that over all of the pairs this is a good way to identify the “right” ones. While we’re at it, the probability of hitting is , so we can expect around of those over all of the pairs.
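The "good way to identify the right ones" can be quantified with a rough heuristic. Assume that each distinguishing byte of a wrong pair has a zero difference independently with probability 1/256. This is a standard random-byte assumption, not something measured here. Under it, a characteristic with k distinguishing bytes lets about N/256^k wrong pairs through out of N. The sketch below plugs in N = 256 · 256, which is the number of cross pairs that gather_pairs from the implementation section loops over for one (i, j).

```python
from fractions import Fraction

# Cross pairs that gather_pairs tries for one (i, j): 256 * 256.
PAIRS_PER_INDEX_PAIR = 256 * 256

def expected_false_pairs(num_pairs, num_bytes):
    """Expected number of wrong pairs passing the filter, assuming each of the
    num_bytes distinguishing ciphertext bytes of a wrong pair has a zero
    difference independently with probability 1/256."""
    return Fraction(num_pairs, 256 ** num_bytes)

for k in (1, 2, 3):
    print(k, float(expected_false_pairs(PAIRS_PER_INDEX_PAIR, k)))
```

This gives 256, 1 and 1/256 expected wrong pairs for one, two and three bytes. A three-byte characteristic is therefore almost free of noise, while a one-byte characteristic lets hundreds of wrong pairs slip through. That fits the later decision to discard one-byte characteristics.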

So now we can get a bunch of pairs with . But what are and ? Remember, , so (call , )

Similarly, .

So , aka really means . is close to being injective in terms of , so we can construct a relatively small collection of possible pairs . I built a certain LUT to speed it up.
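Here is a sketch of the candidate enumeration for a single right pair. It assumes, as find_k_byte in the implementation section does, that a right pair is one where the output differences of f(i, ·) and f(j, ·) cancel modulo 256. The two byte functions are hypothetical random permutations standing in for the challenge's f, and true_key is a made-up pair of key bytes. The table plays the role of the LUT but keeps every matching input instead of just one.

```python
import random

rng = random.Random(2025)
# Hypothetical stand-ins for f(i, .) and f(j, .): random byte permutations.
f_i = rng.sample(range(256), 256)
f_j = rng.sample(range(256), 256)

def diff_table(f):
    """lut[dx][dy] = every x with f[x + dx] - f[x] == dy (mod 256)."""
    lut = [[[] for _ in range(256)] for _ in range(256)]
    for dx in range(256):
        for x in range(256):
            lut[dx][(f[(x + dx) % 256] - f[x]) % 256].append(x)
    return lut

lut_j = diff_table(f_j)

def cancels(dxi, dxj, x1i, x1j, k_i, k_j):
    """True if the output differences of f_i and f_j cancel modulo 256."""
    a, b = (x1i + k_i) % 256, (x1j + k_j) % 256
    return (f_i[(a + dxi) % 256] - f_i[a] + f_j[(b + dxj) % 256] - f_j[b]) % 256 == 0

def candidate_keys(dxi, dxj, x1i, x1j):
    """All key-byte pairs (k_i, k_j) consistent with one right pair."""
    cands = set()
    for k_i in range(256):
        a = (x1i + k_i) % 256
        need = (f_i[a] - f_i[(a + dxi) % 256]) % 256  # required f_j difference
        for b in lut_j[dxj][need]:
            cands.add((k_i, (b - x1j) % 256))
    return cands

# Draw random pair data until it is a right pair for the made-up key.
true_key = (0x3C, 0xA7)
while True:
    dxi, dxj = rng.randrange(1, 256), rng.randrange(1, 256)
    x1i, x1j = rng.randrange(256), rng.randrange(256)
    if cancels(dxi, dxj, x1i, x1j, *true_key):
        break

cands = candidate_keys(dxi, dxj, x1i, x1j)
print(len(cands), true_key in cands)
```

For each k_i, the number of matching k_j is one on average. A single pair therefore typically leaves a few hundred candidates out of 65536, which is what makes combining several pairs effective.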

This is only for a single pair, but clearly we can get more of those! As you might guess, only the real pair of appears for every one of those collections. In reality, due to false right pairs and other things, it’s not in 100% of them, but is still easily identifiable.
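Picking out the real pair amounts to a simple vote, and it tolerates a few false right pairs. The sketch below assumes the same cancellation condition as before and uses hypothetical random byte functions and a made-up key. It generates 20 genuine right pairs plus 10 random "false right pairs" and counts how often each key-byte pair appears across the candidate sets.

```python
import random
from collections import Counter, defaultdict

rng = random.Random(7)
# Hypothetical byte functions and key bytes, for illustration only.
f_i = rng.sample(range(256), 256)
f_j = rng.sample(range(256), 256)
true_key = (0x3C, 0xA7)

def cancels(obs, key):
    """True if f_i and f_j output differences cancel for this observation and key."""
    dxi, dxj, x1i, x1j = obs
    a, b = (x1i + key[0]) % 256, (x1j + key[1]) % 256
    return (f_i[(a + dxi) % 256] - f_i[a] + f_j[(b + dxj) % 256] - f_j[b]) % 256 == 0

def candidate_keys(obs):
    """Key-byte pairs consistent with one observation (dxi, dxj, x1i, x1j)."""
    dxi, dxj, x1i, x1j = obs
    by_diff = defaultdict(list)  # f_j output difference -> matching k_j values
    for k_j in range(256):
        b = (x1j + k_j) % 256
        by_diff[(f_j[(b + dxj) % 256] - f_j[b]) % 256].append(k_j)
    cands = set()
    for k_i in range(256):
        a = (x1i + k_i) % 256
        need = (f_i[a] - f_i[(a + dxi) % 256]) % 256
        cands.update((k_i, k_j) for k_j in by_diff.get(need, []))
    return cands

def random_obs():
    return (rng.randrange(1, 256), rng.randrange(1, 256),
            rng.randrange(256), rng.randrange(256))

# 20 right pairs for the made-up key, plus 10 false right pairs (random data).
observations = []
while len(observations) < 20:
    obs = random_obs()
    if cancels(obs, true_key):
        observations.append(obs)
observations += [random_obs() for _ in range(10)]

votes = Counter()
for obs in observations:
    votes.update(candidate_keys(obs))  # each observation votes once per candidate

(best, best_votes), (_, runner_up_votes) = votes.most_common(2)
print(best == true_key, best_votes, runner_up_votes)
```

The true pair collects a vote from every genuine right pair, while any other key-byte pair typically collects only a few. That gap is why a fixed vote threshold is enough in practice.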

By analyzing a couple of right pairs for each index, we can recover and so , which is the master key, so we’re done here.
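The last step only needs M⁻¹ modulo 256, which inv_M provides in the SageMath code. For readers without Sage, here is a pure-Python sketch; the helper names are mine. It relies on the fact that a matrix is invertible modulo 256 exactly when its determinant is odd. For a binary matrix, that is the same as being invertible over GF(2), so Gauss-Jordan elimination with odd pivots goes through. The matrix and key are random stand-ins, not the challenge's.

```python
import random

def mat_vec_mod256(m, v):
    """Matrix-vector product modulo 256."""
    return [sum(a * b for a, b in zip(row, v)) % 256 for row in m]

def inverse_mod256(m):
    """Gauss-Jordan inverse of a square integer matrix modulo 256.
    Succeeds exactly when det(m) is odd; raises ValueError otherwise."""
    n = len(m)
    aug = [[x % 256 for x in row] + [int(r == c) for c in range(n)]
           for r, row in enumerate(m)]
    for col in range(n):
        # An odd pivot is a unit modulo 256, so it can be scaled to 1.
        piv = next((r for r in range(col, n) if aug[r][col] % 2 == 1), None)
        if piv is None:
            raise ValueError("matrix is not invertible modulo 256")
        aug[col], aug[piv] = aug[piv], aug[col]
        inv = pow(aug[col][col], -1, 256)
        aug[col] = [(x * inv) % 256 for x in aug[col]]
        for r in range(n):
            if r != col and aug[r][col]:
                factor = aug[r][col]
                aug[r] = [(x - factor * y) % 256 for x, y in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]

rng = random.Random(3)
# Hypothetical binary matrix; retry until it is invertible.
while True:
    M = [[rng.randint(0, 1) for _ in range(16)] for _ in range(16)]
    try:
        inv_M = inverse_mod256(M)
        break
    except ValueError:
        pass

key = [rng.randrange(256) for _ in range(16)]  # made-up master key
mk = mat_vec_mod256(M, key)                    # the vector the attack recovers
print(mat_vec_mod256(inv_M, mk) == key)        # prints True
```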

The implementation

To make the attack more efficient and easier to code, considering we must encrypt everything we want at once, I used this structure:

def diff_at(i, x):
    diff = [0] * 16
    diff[i] = x
    return vector(diff)
 
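def gen_pts(base, i):
    # 256 plaintexts base + inv_M * (x * e_i); their images under M
    # differ from M * base only at position i
    pts = []
    for x in range(256):
        diff = inv_M * diff_at(i, x)
        pts.append(base + diff)
    return pts
 
# One collection per position: 16 * 256 = 4096 plaintexts in total
pts_collections = [gen_pts(base, i) for i in range(16)]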

This lets us get a pair for any two-position difference in M · pt we might want: take one plaintext from collection i and one from collection j. This is just 16 · 256 = 4096 plaintexts, or 64 KB.
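The "any two-position difference" claim can be checked quickly. Each plaintext is base + M⁻¹·(x·e_i), so its image under M is M·base + x·e_i. The sketch below therefore skips M entirely and works with these images modulo 256. It verifies that the 256 × 256 cross pairs between collections i and j realize every (Δ_i, Δ_j) exactly once, with all other positions equal, and it recomputes the data size. The base image and the indices are arbitrary choices.

```python
# Work directly with the images M * pt (mod 256); by construction,
# plaintext x of collection i has M * pt = M * base + x * e_i.
def m_image(base_m, i, x):
    v = list(base_m)
    v[i] = (v[i] + x) % 256
    return v

base_m = [(0x11 * k) % 256 for k in range(16)]  # arbitrary stand-in for M * base
i, j = 2, 9

hits = {}
only_i_and_j = True
for x in range(256):
    v1 = m_image(base_m, i, x)
    for y in range(256):
        v2 = m_image(base_m, j, y)
        d = [(b - a) % 256 for a, b in zip(v1, v2)]
        only_i_and_j &= all(d[k] == 0 for k in range(16) if k not in (i, j))
        hits[(d[i], d[j])] = hits.get((d[i], d[j]), 0) + 1

num_plaintexts = 16 * 256
print(len(hits), set(hits.values()), only_i_and_j)  # 65536 {1} True
print(num_plaintexts, num_plaintexts * 16)            # 4096 65536 (= 64 KiB)
```

The differences come out as Δ_i = −x and Δ_j = y. Each target difference is therefore hit by exactly one (x, y).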

Here are the functions I used to bruteforce the characteristics:

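Mz = M.change_ring(ZZ)
 
def get_ct_diff(v1_diff):
    # Zero/nonzero pattern of the ciphertext difference: multiply by M once,
    # mark every nonzero byte as 1, then multiply again over ZZ so that the
    # active bytes can never cancel each other out
    v1 = M * v1_diff
    v1 = vector(ZZ, [x != 0 for x in v1])
    v2 = Mz * v1
    return tuple(x != 0 for x in v2)
 
def get_char(i, j):
    # Output positions that are active for differences (1, 1) at (i, j)
    # but inactive for (1, -1); these distinguish the two cases
    v1 = get_ct_diff(diff_at(i, 1) + diff_at(j, 1))
    v2 = get_ct_diff(diff_at(i, 1) + diff_at(j, 255))
    diff = tuple(x - y for x, y in zip(v1, v2))
    assert all(x == 0 or x == 1 for x in diff)
    return tuple(k for k in range(16) if diff[k])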
 
k_recovered = [None] * 16
 
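def find_char(i):
    # Partners j whose characteristic for (i, j) has at least two distinguishing bytes
    js = []
    for jj in range(16):
        if i == jj:
            continue
        char = get_char(i, jj)
        if len(char) >= 2:
            js.append(jj)
    # Prefer a partner whose key byte is still unknown
    best_js = [j for j in js if k_recovered[j] is None]
    if best_js:
        return best_js[0]
    elif js:
        return js[0]
    else:
        return None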

Characteristics with only one distinguishing byte are expected to produce more false right pairs than real ones, so they are discarded. There are fitting characteristics for every byte except byte 15. I just bruteforced that byte directly, as you will see.

Constructing the LUTs:

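def diffs_lut(i):
    # diffs[dx][dy] = some x with f(i, x + dx) - f(i, x) == dy (mod 256);
    # if several x give the same dy, only the last one is kept
    diffs = {}
    for dx in range(256):
        diffs[dx] = {}
        for x in range(256):
            diffs[dx][(f(i, (x + dx) % 256) - f(i, x)) % 256] = x
    return diffs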
 
diffs_luts = [diffs_lut(i) for i in range(16)]

It would be a bit better to keep every matching x here, since the dictionary stores only the last x for each output difference, but I just went with the simple and fast solution.
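Keeping every value is a small change. The sketch below is a variant of diffs_lut that takes the byte function as a plain callable on 0..255 instead of the challenge's f(i, x). It is exercised with a hypothetical random permutation and stores a list of all matching x per output difference.

```python
import random

def diffs_lut_all(f):
    """Like diffs_lut, but for a plain callable f on 0..255 and keeping every x:
    diffs[dx][dy] = [all x with f((x + dx) % 256) - f(x) == dy (mod 256)]."""
    diffs = {}
    for dx in range(256):
        row = {}
        for x in range(256):
            dy = (f((x + dx) % 256) - f(x)) % 256
            row.setdefault(dy, []).append(x)
        diffs[dx] = row
    return diffs

# Hypothetical byte function standing in for f(i, .) from chall.sage.
table = random.Random(5).sample(range(256), 256)
lut = diffs_lut_all(lambda x: table[x])

# Every row partitions all 256 inputs; dx = 0 maps everything to dy = 0.
print(all(sum(len(xs) for xs in row.values()) == 256 for row in lut.values()))  # True
print(lut[0] == {0: list(range(256))})                                          # True
```

With this table, the single lookup `lut[dxj][x] - x1j` in find_k_byte would become a loop over `lut[dxj].get(x, [])`. That loop adds one (k0, k1) candidate per stored value.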

Finding right pairs:

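def gather_pairs(i, j):
    # Candidate right pairs for (i, j): one plaintext from collection i and one
    # from collection j, kept if the ciphertexts agree on every byte in char.
    # pairs_collections[i] holds the (plaintext, ciphertext) pairs of collection i.
    res = set()
 
    char = get_char(i, j)
    prs1 = pairs_collections[i]
    prs2 = pairs_collections[j]
 
    for pt1, ct1 in prs1:
        for pt2, ct2 in prs2:
            ctdiff = [x != y for x, y in zip(ct1, ct2)]
            if all(not ctdiff[c] for c in char):
                # Bytes i and j of M * pt1, and the difference of M * pt there
                m_pt1 = M * vector(pt1)
                x1i = m_pt1[i]
                x1j = m_pt1[j]
                m_pt2 = M * vector(pt2)
                m_v0 = m_pt2 - m_pt1
                dxi, dxj = m_v0[i], m_v0[j]
 
                res.add((dxi, dxj, x1i, x1j))
 
    return res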

This could be made a bit faster by not trying all the pairs directly, but it isn't really worth it.
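One way to avoid the all-pairs loop rests on the observation that only the filter matters. A pair passes exactly when the two ciphertexts have equal bytes at every position of char. Collection j can therefore be bucketed by those bytes, and each ciphertext of collection i compared only with its own bucket. The function names and the synthetic data are mine. In gather_pairs, the ciphertexts would come from pairs_collections[i] and pairs_collections[j].

```python
import random
from collections import defaultdict

def matching_pairs_bruteforce(cts1, cts2, char):
    """Index pairs (a, b) whose ciphertexts agree on every position in char."""
    return {(a, b) for a, c1 in enumerate(cts1) for b, c2 in enumerate(cts2)
            if all(c1[k] == c2[k] for k in char)}

def matching_pairs_bucketed(cts1, cts2, char):
    """Same result, but cts2 is indexed by its bytes at the positions in char,
    so each ciphertext of cts1 is only compared with its own bucket."""
    buckets = defaultdict(list)
    for b, c2 in enumerate(cts2):
        buckets[tuple(c2[k] for k in char)].append(b)
    return {(a, b) for a, c1 in enumerate(cts1)
            for b in buckets.get(tuple(c1[k] for k in char), [])}

rng = random.Random(11)
# Synthetic 16-byte ciphertexts; only the filtering step is illustrated.
cts1 = [bytes(rng.randrange(256) for _ in range(16)) for _ in range(256)]
cts2 = [bytearray(rng.randrange(256) for _ in range(16)) for _ in range(256)]
char = (4, 13)
for a, b in [(3, 10), (100, 200)]:  # plant a couple of matches
    for k in char:
        cts2[b][k] = cts1[a][k]

fast = matching_pairs_bucketed(cts1, cts2, char)
print(fast == matching_pairs_bruteforce(cts1, cts2, char), len(fast))
```

The bucketed version does roughly 2 · 256 dictionary operations per (i, j) instead of 65536 comparisons. Given how cheap the rest of the attack is, the gain is modest, in line with the remark above.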

Finding a single byte:

def find_k_byte(i):
    j = find_char(i)
    if j is None:
        return None
 
    lut = diffs_luts[j]
    res = gather_pairs(i, j)
    sol = {}
 
    for dxi, dxj, x1i, x1j in res:
        # Key-byte pairs (k0 at position i, k1 at position j) for which the
        # output differences of f(i, .) and f(j, .) cancel out
        pairs = set()
        for k0 in range(256):
            x = (f(i, x1i + k0) - f(i, x1i + k0 + dxi)) % 256
            if x in lut[dxj]:
                k1 = int(lut[dxj][x] - x1j) % 256
                pairs.add((k0, k1))
        # Each candidate right pair votes once for every key-byte pair it allows
        for pair in pairs:
            sol[pair] = sol.get(pair, 0) + 1
 
    right_pairs = [pair for pair, v in sol.items() if v > 50]
    assert len(right_pairs) == 1, right_pairs
    right_pair = right_pairs[0]
 
    if k_recovered[i] is not None:
        assert k_recovered[i] == right_pair[0]
    else:
        k_recovered[i] = right_pair[0]
    if k_recovered[j] is not None:
        assert k_recovered[j] == right_pair[1]
    else:
        k_recovered[j] = right_pair[1]

Pretty neat and fast. The threshold of 50 votes for the correct key-byte pair is arbitrary, but it works fine.

The final steps:

def ensure_k_recovered():
    for i in range(16):
        if k_recovered[i] is not None:
            continue
        find_k_byte(i)
        print(k_recovered)
 
def get_key_candidates():
    # Exactly one byte of M * key (the one without a usable characteristic)
    # is still unknown: try all 256 values and map each candidate back with inv_M
    keys = []
    idx = [i for i in range(16) if k_recovered[i] is None]
    assert len(idx) == 1
    i = idx[0]
    for x in range(256):
        k_candidate = k_recovered.copy()
        k_candidate[i] = x
        key = inv_M * vector(k_candidate)
        key = list(key)
        keys.append(key)
 
    return keys
 
def get_key():
    ensure_k_recovered()
    # Keep the candidate that reproduces a known plaintext/ciphertext pair
    pt, ct = pairs_collections[0][0]
    valid = []
    for key in get_key_candidates():
        if encrypt_block(pt, key=key) == ct:
            valid.append(key)
    assert len(valid) == 1
    return valid[0]

To actually get the ciphertexts into the program, I just made it write the plaintexts to a file and ask me to encrypt them.
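The exact file format isn't described here, so the following sketch assumes one: one hex-encoded 16-byte block per line, with ciphertexts returned in the same order. The helper names are hypothetical. A dummy XOR stands in for the step where the plaintexts are actually encrypted on the remote. Sage vectors would need converting to plain integers first, for example with `[int(b) for b in pt]`.

```python
import os
import tempfile

def write_plaintexts(path, pts):
    """Write one hex-encoded 16-byte plaintext per line (assumed format)."""
    with open(path, "w") as fh:
        for pt in pts:
            fh.write(bytes(pt).hex() + "\n")

def read_blocks(path):
    """Read hex-encoded blocks back, one per non-empty line."""
    with open(path) as fh:
        return [bytes.fromhex(line.strip()) for line in fh if line.strip()]

def pair_up(pts, cts):
    """Zip plaintexts with their ciphertexts, like pairs_collections entries."""
    assert len(pts) == len(cts), "every plaintext needs exactly one ciphertext"
    return [(bytes(pt), ct) for pt, ct in zip(pts, cts)]

pts = [[(17 * n + k) % 256 for k in range(16)] for n in range(4)]
with tempfile.TemporaryDirectory() as tmp:
    pt_path = os.path.join(tmp, "plaintexts.txt")
    ct_path = os.path.join(tmp, "ciphertexts.txt")
    write_plaintexts(pt_path, pts)
    # Placeholder for "ask me to encrypt it": a dummy XOR instead of the remote.
    with open(ct_path, "w") as fh:
        for blk in read_blocks(pt_path):
            fh.write(bytes(b ^ 0xFF for b in blk).hex() + "\n")
    pairs = pair_up(pts, read_blocks(ct_path))

print(len(pairs), pairs[0][0].hex())
```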

Afterthoughts

In reality, a couple of right pairs is enough, so you can cut down on the encryptions and it will still work. The byte that has to be bruteforced anyway can be skipped, too. I was able to easily reduce the data to just 15 KB of encryptions.

Theoretically, just two right pairs are expected to be enough, and we can skip another byte, since bruteforcing one more byte is not a lot of extra work. So the expected theoretical amount of ciphertext is even lower: 322 blocks, or a little over 5 KB, while retaining about the same compute.

I want to note that, besides being efficient, this is one of the most complicated ways to solve this task! I’m not (yet) experienced in differential cryptanalysis, so I’m kind of rolling my own. Make sure to check out the other solutions for more interesting crypto!