Zh3r0 CTF V2 - twist_and_shout

  • Author: qopruzjf
  • Date:
  • Solves: 29

Challenge

Wise men once said, “Well, shake it up, baby, now Twist and shout come on and work it on out” I obliged, not the flag is as twisted as my sense of humour

nc crypto.zh3r0.cf 5555

Source

```python
from secret import flag
import os
import random

state_len = 624*4
right_pad = random.randint(0,state_len-len(flag))
left_pad = state_len-len(flag)-right_pad
state_bytes = os.urandom(left_pad)+flag+os.urandom(right_pad)
state = tuple( int.from_bytes(state_bytes[i:i+4],'big') for i in range(0,state_len,4) )
random.setstate((3,state+(624,),None))
outputs = [random.getrandbits(32) for i in range(624)]
print(*outputs,sep='\n')
```

Solution

The challenge is about recovering the original state of the random number generator, which contains the flag. Python’s random module uses the MT19937 Mersenne Twister, and given 624 consecutive 32-bit outputs we can recover the state the generator produced them from. The catch is that the state is “twisted” before anything is generated, so the state we recover is the post-twist state, not the pre-twist state that holds the flag. The goal, then, is to untwist the recovered state so that we can get the flag.

The final script used to solve the challenge was written by my teammate willwam845, and this writeup is mainly to explain the ideas involved in tackling this challenge.

Recovering the RNG’s state

With a bit of searching, it’s not hard to find that you can recover a Mersenne twister’s state from 624 outputs and its index, which is 0 in this case because the outputs are generated right after a twist. When the index hits 624 (which is what the challenge sets it to), the RNG first performs a twist and then resets its index to 0 before outputting anything.
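This can be checked locally with a short sketch, assuming CPython’s `random.Random`. Here `twist` is a transcription of the standard MT19937 twist (the same loop as the twist code later in this writeup) and `temper` is the standard output tempering; neither comes from the challenge itself. Loading a random state with index 624 and drawing 624 outputs gives exactly the tempered values of the twisted state:

```python
import os
import random

N, M = 624, 397
A, UPPER, LOWER = 0x9908B0DF, 0x80000000, 0x7FFFFFFF

def twist(mt):
    """Return a twisted copy of a 624-word MT19937 state (in-place order, like CPython)."""
    mt = list(mt)
    for i in range(N):
        x = (mt[i] & UPPER) | (mt[(i + 1) % N] & LOWER)
        xA = x >> 1
        if x & 1:
            xA ^= A
        mt[i] = mt[(i + M) % N] ^ xA
    return mt

def temper(y):
    """Standard MT19937 output tempering."""
    y ^= y >> 11
    y ^= (y << 7) & 0x9D2C5680
    y ^= (y << 15) & 0xEFC60000
    y ^= y >> 18
    return y

# A random pre-twist state, loaded with index 624 exactly as the challenge does.
state = [int.from_bytes(os.urandom(4), "big") for _ in range(N)]
rng = random.Random()
rng.setstate((3, tuple(state) + (624,), None))
outputs = [rng.getrandbits(32) for _ in range(N)]

# Index 624 forces a twist before the first output, so output k is temper(twisted[k]).
twisted = twist(state)
assert outputs == [temper(v) for v in twisted]
print("outputs are the tempered post-twist state")
```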

When a Mersenne twister outputs a value (specifically a 32-bit output), it looks up the value at the current index of its state and tempers it; the tempered value is the RNG’s output. This means that if we can untemper the outputs, we trivially recover the post-twist state of the RNG. This is what willwam ended up doing. I wasn’t aware of how to reverse the temper function, so I simply used the same Mersenne twister solver that we used in the Real Mersenne challenge. A link to the writeup for that is here, by Quintec; it showcases the usefulness of the solver more clearly. When comparing the two solve scripts at the end, just keep in mind that my use of the solver has effectively the same result as untempering all 624 values given by the server.
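As a minimal sketch of untempering (assuming CPython’s MT19937 tempering constants), each tempering step can be inverted by fixed-point iteration. This is a different formulation from the closed-form masks in willwam’s script, though it computes the same inverse. The check at the end relies on a freshly seeded `random.Random` also starting with index 624, so after 624 draws `getstate()` holds exactly the post-twist state those outputs were tempered from.

```python
import random

def undo_right_shift_xor(y, shift):
    """Invert y = x ^ (x >> shift); each pass fixes `shift` more top bits."""
    x = y
    for _ in range(32 // shift + 1):
        x = y ^ (x >> shift)
    return x

def undo_left_shift_xor_and(y, shift, mask):
    """Invert y = x ^ ((x << shift) & mask); each pass fixes `shift` more low bits."""
    x = y
    for _ in range(32 // shift + 1):
        x = y ^ ((x << shift) & mask)
    return x

def untemper(y):
    """Undo MT19937 tempering by applying the inverse steps in reverse order."""
    y = undo_right_shift_xor(y, 18)
    y = undo_left_shift_xor_and(y, 15, 0xEFC60000)
    y = undo_left_shift_xor_and(y, 7, 0x9D2C5680)
    y = undo_right_shift_xor(y, 11)
    return y

rng = random.Random(1234)
outputs = [rng.getrandbits(32) for _ in range(624)]
recovered = [untemper(o) for o in outputs]

# The internal state (minus the trailing index) is the post-twist state.
assert recovered == list(rng.getstate()[1][:624])
print("recovered the post-twist state from 624 outputs")
```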

The source for the solver (it is not ours and is open source): link

Undoing the twist

Here is the main substance of the challenge. We want to figure out how to untwist the recovered state of the RNG so that we can get the state with the flag. First, we looked around for the source of the twist function. Here is the one I ended up using:

```python
(w, n, m, r) = (32, 624, 397, 31)
a = 0x9908B0DF
(u, d) = (11, 0xFFFFFFFF)
(s, b) = (7, 0x9D2C5680)
(t, c) = (15, 0xEFC60000)
l = 18
f = 1812433253
lower_mask = 0x7FFFFFFF
upper_mask = 0x80000000

def twist(MT):
    for i in range(0, n):
        x = (MT[i] & upper_mask) + (MT[(i+1) % n] & lower_mask)
        xA = x >> 1
        if (x % 2) != 0:
            xA = xA ^ a
        MT[i] = MT[(i + m) % n] ^ xA
```

The original code is taken from here. I changed some of the parameters, since they differed from those of the Mersenne twister that Python’s random module uses. The other twist functions we found were pretty much the same. (I later noticed that nearly identical twist code is also in the aforementioned Mersenne cracker source.)

So, for each value in the state, the twist function takes the top bit of the current value (MT[i] & upper_mask) and the bottom 31 bits of the next value (MT[(i+1) % n] & lower_mask), and combines them into x. It then shifts off the last bit of x (x >> 1) and, if that last bit was 1, XORs the result with the parameter a. Finally, the value at index i is replaced with this result XORed with the value m positions further along in the state, MT[(i + m) % n].
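To make the bit bookkeeping concrete, the sketch below isolates a single twist step (`twist_step` is a hypothetical helper, not part of the original code; `|` replaces the original `+`, which is equivalent here because the two masks do not overlap). It also exposes a property that falls out of the arithmetic: `x >> 1` is always below 2^31 while `a = 0x9908B0DF` has its top bit set, so the top bit of `xA` records whether the XOR with `a` happened. The approach in this writeup guesses that bit instead, which works as well.

```python
UPPER, LOWER, A = 0x80000000, 0x7FFFFFFF, 0x9908B0DF

def twist_step(mt_i, mt_next):
    """One twist step on MT[i] and MT[i+1]; the new MT[i] is MT[(i + m) % n] ^ xA."""
    x = (mt_i & UPPER) | (mt_next & LOWER)  # top bit of MT[i], low 31 bits of MT[i+1]
    xA = x >> 1                             # drop the last bit of x
    if x & 1:                               # x's last bit is MT[i+1]'s last bit
        xA ^= A
    return x, xA

x, xA = twist_step(0xDEADBEEF, 0x12345679)
print(hex(x), hex(xA))  # 0x92345679 0xd0129be3

# x >> 1 is below 2**31 and A has its top bit set,
# so the top bit of xA equals the last bit of x.
assert (xA >> 31) == (x & 1)
```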

Now, let’s try reversing it. We want to go backwards, starting from the last value modified (the one at index 623), and undo the XOR operations. Here is some pseudocode for reversing a single state value at index i:

```
xA = MT[(i + m) % n] ^ MT[i]          # undo MT[i] = MT[(i + m) % n] ^ xA
if pretwist_MT[(i + 1) % n] was odd:
    xA ^= a                           # undo xA = xA ^ a
x = (xA << 1) + (last bit of pretwist_MT[(i + 1) % n])   # undo xA = x >> 1
```
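A runnable version of that pseudocode, as a minimal sketch: `untwist_x` is a hypothetical name, `guessed_low_bit` stands in for the unknown last bit of `pretwist_MT[(i + 1) % n]`, and `twist` only exists to generate test data from a seeded `random.Random`. With the right guess, the result is exactly the `x` the forward twist computed at step `i`, for any `i` whose partner `MT[(i + m) % n]` was already twisted (227 to 622).

```python
import random

N, M = 624, 397
A, UPPER, LOWER = 0x9908B0DF, 0x80000000, 0x7FFFFFFF

def twist(mt):
    """Forward MT19937 twist on a copy of the state (used only to build test data)."""
    mt = list(mt)
    for i in range(N):
        x = (mt[i] & UPPER) | (mt[(i + 1) % N] & LOWER)
        xA = (x >> 1) ^ (A if x & 1 else 0)
        mt[i] = mt[(i + M) % N] ^ xA
    return mt

def untwist_x(post, i, guessed_low_bit):
    """Recover the x used at twist step i from the post-twist state, given a guess for x's last bit."""
    xA = post[i] ^ post[(i + M) % N]                   # undo MT[i] = MT[(i + m) % n] ^ xA
    if guessed_low_bit:
        xA ^= A                                        # undo xA = xA ^ a
    return ((xA << 1) | guessed_low_bit) & 0xFFFFFFFF  # undo xA = x >> 1

rng = random.Random(0)
pre = [rng.getrandbits(32) for _ in range(N)]
post = twist(pre)

i = 500
true_bit = pre[i + 1] & 1  # in the real attack this bit is unknown and must be guessed
x = untwist_x(post, i, true_bit)
assert x == ((pre[i] & UPPER) | (pre[i + 1] & LOWER))
```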

A few glaring issues immediately become apparent:

  1. We don’t know MT[(i + m) % n] for all values of i. For example, consider i=0: no twisting has occurred yet, so MT[(i + m) % n] comes from the pre-twist state, which we don’t have. In contrast, index 623 is the last value to be twisted, so MT[(i + m) % n] is an already twisted value, which we do have (from the recovered state).

  2. We don’t know pretwist_MT[(i+1) % n], so we can’t be sure whether to do xA ^= a. Instead, we can guess the last bit and check both possibilities each time. However, if we have to do this for every value, it becomes infeasible: with 624 state values, we’d have to check 2^624 combinations to get the full state right.

  3. Once we guess x, it tells us more about pretwist_MT[(i + 1) % n] than about pretwist_MT[i]: its bottom 31 bits come from the former, and only its top bit comes from the latter.

With a bit of thinking, we can get around most of these issues, at least to some degree. For the first issue, we can restrict ourselves to the values of i for which MT[(i + m) % n] has already been twisted. With m = 397, these are i=227 to i=623. We don’t necessarily have to worry about the other values, since we only have to untwist the part of the state that contains the flag. In other words, if we can successfully untwist this portion and the flag lies inside it, we’re still good. Since the flag’s position in the state is random, we can simply keep reconnecting until we get an instance where the flag falls in this portion.
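The index range, and how often the flag lands inside the reachable part of the state, can be worked out from the challenge source. This is a sketch: `chance_flag_in_window` is a hypothetical helper, the example assumes the flag is the 79-byte string shown at the end of this writeup, and it only checks that the flag lies within the reachable words (words shared between the flag and random padding may still fail the printability check).

```python
N, M = 624, 397
STATE_LEN = 4 * N  # state size in bytes, as in the challenge source

# Steps i whose partner MT[(i + m) % n] was already twisted (so we know it).
usable = [i for i in range(N) if (i + M) % N < i]
print(usable[0], usable[-1])  # 227 623

# Undoing step i exposes the low 31 bits of pre-twist word i + 1, so words
# 228..623 are reachable; word 228 starts at byte 4 * 228 = 912 of state_bytes.
FIRST_BYTE = 4 * (usable[0] + 1)

def chance_flag_in_window(flag_len):
    """P(left_pad >= FIRST_BYTE) when left_pad is uniform on 0..STATE_LEN - flag_len."""
    choices = STATE_LEN - flag_len + 1
    good = max(0, STATE_LEN - flag_len - FIRST_BYTE + 1)
    return good / choices

print(chance_flag_in_window(79))  # about 0.62
```

So, under these assumptions, roughly three connections in five should put the flag inside the recoverable window.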

For the second issue, we can actually “verify” whether a single untwisted value is correct, because we know the flag is printable, while the other values in the state are random and usually won’t be (it is unlikely that all four bytes of a random 32-bit value are printable). Being able to verify each state value independently also drastically reduces the number of possibilities to check, from 2^(624 - 227) to 2*(624 - 227), which is easily feasible since that part is done locally. This is the idea that willwam came up with to complete the final stretch of the challenge.
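A minimal version of that printability test, mirroring the `range(32, 128)` check in my script (so 0x7F is accepted too); `is_printable_word` is a hypothetical helper name. Unlike the script, which goes through `long_to_bytes` and therefore drops leading zero bytes, this sketch always examines all four bytes. It also estimates how often a random word passes by accident, which is why a few stray words can still show up around the flag.

```python
def is_printable_word(value):
    """True if all four big-endian bytes of a 32-bit value lie in range(32, 128)."""
    return all(32 <= byte < 128 for byte in value.to_bytes(4, "big"))

# Chance that a uniformly random 32-bit word passes by accident: 96 of 256
# byte values are accepted, independently for each of the 4 bytes.
false_positive_rate = 96**4 / 256**4
print(false_positive_rate)  # roughly 0.0198, i.e. about 2%
```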

For the third issue, since x gives us 31 bits of pretwist_MT[(i + 1) % n] and only one bit of pretwist_MT[i], we can use the x obtained from index i to determine the bottom 31 bits of pretwist_MT[(i + 1) % n]. This still leaves the top bit of pretwist_MT[(i + 1) % n] unknown, but just like before, we can simply test both possibilities, bringing the total to 4*(624 - 227). Naturally, this is still very feasible.
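Putting the two guesses together, each usable index yields four candidates for the next pre-twist word. This is a sketch under stated assumptions: `candidates` is a hypothetical helper, and the state is generated locally from a seeded `random.Random` purely to check that the true pre-twist word is always among the four. In the real attack, the printability test is what picks among them.

```python
import random

N, M = 624, 397
A, UPPER, LOWER = 0x9908B0DF, 0x80000000, 0x7FFFFFFF

def twist(mt):
    """Forward MT19937 twist on a copy of the state (only used to build test data)."""
    mt = list(mt)
    for i in range(N):
        x = (mt[i] & UPPER) | (mt[(i + 1) % N] & LOWER)
        xA = (x >> 1) ^ (A if x & 1 else 0)
        mt[i] = mt[(i + M) % N] ^ xA
    return mt

def candidates(post, i):
    """The 4 guesses for pre-twist word i + 1 from undoing twist step i (227 <= i <= 622)."""
    xA = post[i] ^ post[(i + M) % N]
    guesses = []
    for low_bit in (0, 1):                  # guess the last bit of x
        base = (xA ^ A) if low_bit else xA
        x = (base << 1) | low_bit
        for top_bit in (0, UPPER):          # guess the top bit that x does not carry
            guesses.append((x & LOWER) | top_bit)
    return guesses

rng = random.Random(2021)
pre = [rng.getrandbits(32) for _ in range(N)]
post = twist(pre)

# The true pre-twist word is always among the 4 candidates.
assert all(pre[i + 1] in candidates(post, i) for i in range(227, 623))
```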

One last thing to note: we don’t actually need to check index 623, since there MT[(i + 1) % n] is MT[0], which has already been twisted. The bottom 31 bits we would get belong to the twisted MT[0], not the pre-twist one, so they tell us nothing new.
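A quick local check of this, as a sketch: `twist` transcribes the twist code shown earlier and the state comes from a seeded `random.Random`, purely for illustration. With the correct guess for the last bit, undoing step 623 returns the bottom 31 bits of the post-twist MT[0]; the only pre-twist information it carries is the top bit of MT[623].

```python
import random

N, M = 624, 397
A, UPPER, LOWER = 0x9908B0DF, 0x80000000, 0x7FFFFFFF

def twist(mt):
    """Forward MT19937 twist on a copy of the state."""
    mt = list(mt)
    for i in range(N):
        x = (mt[i] & UPPER) | (mt[(i + 1) % N] & LOWER)
        xA = (x >> 1) ^ (A if x & 1 else 0)
        mt[i] = mt[(i + M) % N] ^ xA
    return mt

rng = random.Random(7)
pre = [rng.getrandbits(32) for _ in range(N)]
post = twist(pre)

# Undo step 623 with the correct guess for x's last bit, which at this step
# is the last bit of the *already twisted* MT[0].
xA = post[623] ^ post[(623 + M) % N]
low_bit = post[0] & 1
x = (((xA ^ A) if low_bit else xA) << 1) | low_bit

assert (x & LOWER) == (post[0] & LOWER)  # bottom 31 bits: post-twist MT[0]
assert (x >> 31) == (pre[623] >> 31)     # top bit: pre-twist MT[623]
```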

Here is the script that I used:

```python
from Crypto.Util.number import long_to_bytes
from pwn import *
from z3 import *
from random import Random
from itertools import count
import time
import logging

logging.basicConfig(format='STT> %(message)s')
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

SYMBOLIC_COUNTER = count()

class Untwister:
    def __init__(self):
        name = next(SYMBOLIC_COUNTER)
        self.MT = [BitVec(f'MT_{i}_{name}', 32) for i in range(624)]
        self.index = 0
        self.solver = Solver()

    #This particular method was adapted from https://www.schutzwerk.com/en/43/posts/attacking_a_random_number_generator/
    def symbolic_untamper(self, solver, y):
        name = next(SYMBOLIC_COUNTER)

        y1 = BitVec(f'y1_{name}', 32)
        y2 = BitVec(f'y2_{name}' , 32)
        y3 = BitVec(f'y3_{name}', 32)
        y4 = BitVec(f'y4_{name}', 32)

        equations = [
            y2 == y1 ^ (LShR(y1, 11)),
            y3 == y2 ^ ((y2 << 7) & 0x9D2C5680),
            y4 == y3 ^ ((y3 << 15) & 0xEFC60000),
            y == y4 ^ (LShR(y4, 18))
        ]

        solver.add(equations)
        return y1

    def symbolic_twist(self, MT, n=624, upper_mask=0x80000000, lower_mask=0x7FFFFFFF, a=0x9908B0DF, m=397):
        '''
            This method models MT19937 function as a Z3 program
        '''
        MT = [i for i in MT] #Just a shallow copy of the state

        for i in range(n):
            x = (MT[i] & upper_mask) + (MT[(i+1) % n] & lower_mask)
            xA = LShR(x, 1)
            xB = If(x & 1 == 0, xA, xA ^ a) #Possible Z3 optimization here by declaring auxiliary symbolic variables
            MT[i] = MT[(i + m) % n] ^ xB

        return MT

    def get_symbolic(self, guess):
        name = next(SYMBOLIC_COUNTER)
        ERROR = 'Must pass a string like "?1100???1001000??0?100?10??10010" where ? represents an unknown bit'

        assert type(guess) == str, ERROR
        assert all(map(lambda x: x in '01?', guess)), ERROR
        assert len(guess) <= 32, "One 32-bit number at a time please"
        guess = guess.zfill(32)

        self.symbolic_guess = BitVec(f'symbolic_guess_{name}', 32)
        guess = guess[::-1]

        for i, bit in enumerate(guess):
            if bit != '?':
                self.solver.add(Extract(i, i, self.symbolic_guess) == bit)

        return self.symbolic_guess


    def submit(self, guess):
        '''
            You need 624 numbers to completely clone the state.
                You can input less than that though and this will give you the best guess for the state
        '''
        if self.index >= 624:
            name = next(SYMBOLIC_COUNTER)
            next_mt = self.symbolic_twist(self.MT)
            self.MT = [BitVec(f'MT_{i}_{name}', 32) for i in range(624)]
            for i in range(624):
                self.solver.add(self.MT[i] == next_mt[i])
            self.index = 0

        symbolic_guess = self.get_symbolic(guess)
        symbolic_guess = self.symbolic_untamper(self.solver, symbolic_guess)
        self.solver.add(self.MT[self.index] == symbolic_guess)
        self.index += 1

    def get_random(self):
        '''
            This will give you a random.Random() instance with the cloned state.
        '''
        logger.debug('Solving...')
        start = time.time()
        self.solver.check()
        model = self.solver.model()
        end = time.time()
        logger.debug(f'Solved! (in {round(end-start,3)}s)')

        #Compute best guess for state
        state = list(map(lambda x: model[x].as_long(), self.MT))
        result_state = (3, tuple(state+[self.index]), None)
        r = Random()
        r.setstate(result_state)
        return r

# parameters for python's random mersenne twister. Not all are used.
(w, n, m, r) = (32, 624, 397, 31)
a = 0x9908B0DF
(u, d) = (11, 0xFFFFFFFF)
(s, b) = (7, 0x9D2C5680)
(t, c) = (15, 0xEFC60000)
l = 18
f = 1812433253
lower_mask = 0x7FFFFFFF
upper_mask = 0x80000000
host, port = "crypto.zh3r0.cf", 5555

def get(MT):
    valids = ['']*624
    for i in range(228, 624):
        valids[i] = [] # for storing the bytes if they are all printable
        xA = MT[i-1] ^ MT[(i-1 + m) % 624]
        # possibilities..
        xA1, xA2 = xA, xA ^ a # 2 possibilities based on last bit of original x
        x1, x2 = xA1 << 1, (xA2 << 1) + 1
        p11, p12 = (x1 & lower_mask), (x1 & lower_mask) | upper_mask # 1st bit 0, 1
        p21, p22 = (x2 & lower_mask), (x2 & lower_mask) | upper_mask # 1st bit 0, 1
        for p in (p11, p12, p21, p22):
            if all([b in range(32, 128) for b in long_to_bytes(p)]): # are all the bytes in the untwisted value printable?
                valids[i].append(long_to_bytes(p).decode())
    return ''.join(sum(list(filter(lambda x: x, valids)), []))

# recover the state
ut = Untwister()
r = remote(host, port)
for i in range(624):
    ut.submit(bin(int(r.recvline().decode().strip('\n')))[2:].zfill(32))
prophet = ut.get_random()
state = list(prophet.getstate()[1][:-1])

# untwist and get the flag
print(get(state))
```

And for reference, here is the script willwam used to get us the solve (I added some comments for clarity):

```python
from Crypto.Util.number import long_to_bytes
def untempering(y):
    y ^= (y >> 18)
    y ^= (y << 15) & 0xefc60000
    y ^= ((y <<  7) & 0x9d2c5680) ^ ((y << 14) & 0x94284000) ^ ((y << 21) & 0x14200000) ^ ((y << 28) & 0x10000000)
    y ^= (y >> 11) ^ (y >> 22)
    return y

def get_seeds(m0, m227):
    seeds = []
    # 2 possibilities based on last bit of assumed x. a = 0x9908b0df
    y_even = (m227 ^ m0) << 1
    y_odd = (((m227 ^ m0 ^ 0x9908b0df) << 1) & 0xffffffff) | 1
    for y in (y_even, y_odd):
        for oldm227_upperbit in (0, 0x80000000):
            for oldm228_upperbit in (0, 0x80000000):
                n = (y ^ oldm227_upperbit ^ oldm228_upperbit) & 0xffffffff
                seeds.append(n)
    return list(set(seeds))

# recovering the state post-twist; af.txt holds the 624 outputs from the server
# (split() rather than split("\n") so that a trailing newline does not break int())
data = [untempering(int(x)) for x in open("af.txt").read().split()]
o = b""
for i in range(397):
    seeds = get_seeds(data[i], data[(i+227)%624])
    for s in seeds:
        if len(str(long_to_bytes(s))) < 9: # checking if the untwisted value is valid
            o += (long_to_bytes(s))

print(o)
```

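(Some code is taken from an old writeup here.) The len(str(long_to_bytes(s))) < 9 condition is another way to check that an untwisted value is valid: str() of four printable bytes, such as b'zh3r', is 7 characters long, while any byte shown as a \xNN escape adds three extra characters and pushes the length to at least 10.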
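A small sketch of that check, with a hypothetical `looks_valid` helper and `int.to_bytes` standing in for pycryptodome’s `long_to_bytes` (assumed here to behave the same way, dropping leading zero bytes). One edge case worth knowing: `\t`, `\n`, `\r` and an escaped backslash are rendered with only two characters, so a value containing a single one of them still passes.

```python
def looks_valid(s):
    """Mimic len(str(long_to_bytes(s))) < 9 without pycryptodome."""
    raw = s.to_bytes(max(1, (s.bit_length() + 7) // 8), "big")  # no leading zero bytes
    return len(str(raw)) < 9

print(str(b"zh3r"), len(str(b"zh3r")))        # b'zh3r' 7
print(str(b"zh3\x00"), len(str(b"zh3\x00")))  # b'zh3\x00' 10
print(str(b"\nh3r"), len(str(b"\nh3r")))      # b'\nh3r' 8
```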

The flag: zh3r0{7h3_fu7ur3_m1gh7_b3_c4p71v471ng_bu7_n0w_y0u_kn0w_h0w_t0_l00k_a7_7h3_p457}

Thanks to Zh3r0 CTF for the challenge!