Zh3r0 CTF V2 - Real Mersenne

crypto rng

Do you believe in games of luck? I hope you make your guesses real or you'll be floating around.

nc crypto.zh3r0.cf 4444

Source Analysis

import random
from secret import flag
from fractions import Fraction

def score(a,b):
    if abs(a-b)<1/2**10:
        # capping score to 1024 so you dont get extra lucky
        return Fraction(2**10)
    return Fraction(2**53,int(2**53*a)-int(2**53*b))

total_score = 0
for _ in range(2000):
    try:
        x = random.random()
        y = float(input('enter your guess:\n'))
        round_score = score(x,y)
        total_score+=float(round_score)
        print('total score: {:0.2f}, round score: {}'.format(
            total_score,round_score))
        if total_score>10**6:
            print(flag)
            exit(0)
    except:
        print('Error, exiting')
        exit(1)
else:
    print('Maybe better luck next time')

We have to guess floats generated by random.random(), and we have 2000 attempts to push the total score above 1,000,000. A round score is capped at 1024, so even perfect guesses need 977 rounds (977 × 1024 = 1,000,448). The strategy is therefore to spend the first guesses cloning the state of the random number generator, and then use at most 977 perfect guesses at the end to collect the required points.

Python's Random

Python's random module uses a Mersenne Twister internally, as hinted by the title of the challenge. So how does a Mersenne Twister work? The short answer is through a series of states, each consisting of 624 integers (in most cases, including this one, 32-bit). When you ask for random values, some of these 624 values are used up, and each value is used only once. For example, if you call random.getrandbits(32), exactly one 32-bit state value is used, after being tempered (passed through a fixed series of shifts and masks). Tempering improves the statistical quality of the output, but it is invertible, so a full 32-bit output reveals the state value behind it. Once all 624 values have been used, they are "twisted" to generate a set of 624 new values, and the cycle repeats. (The generator has a period of $2^{19937} - 1$, which is why it is often called MT19937.)
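To make the "one state value per 32-bit output" claim concrete, here is a small local check. It is only a sketch: the seed is arbitrary, the helper name temper is mine, and it relies on the layout of random.getstate() in CPython (624 state words followed by the current index).

```python
import random

def temper(y):
    """MT19937 output tempering applied to one 32-bit state word."""
    y ^= y >> 11
    y ^= (y << 7) & 0x9D2C5680
    y ^= (y << 15) & 0xEFC60000
    y ^= y >> 18
    return y & 0xFFFFFFFF

rng = random.Random(1234)   # arbitrary seed, for the demo only
rng.getrandbits(32)         # the first call forces a twist, so the state below is fresh

version, internal, gauss_next = rng.getstate()
words, index = internal[:624], internal[624]

# Each further 32-bit output should be the tempered word at the next index.
matches = all(temper(words[index + k]) == rng.getrandbits(32) for k in range(10))
print(index, matches)
```

The tempering constants are the same ones that appear in the solver code further down.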

But how does that produce floats between 0 and 1? Well, let's take a look at the CPython source code.

/* random_random is the function named genrand_res53 in the original code;
 * generates a random number on [0,1) with 53-bit resolution; note that
 * 9007199254740992 == 2**53; I assume they're spelling "/2**53" as
 * multiply-by-reciprocal in the (likely vain) hope that the compiler will
 * optimize the division away at compile-time.  67108864 is 2**26.  In
 * effect, a contains 27 random bits shifted left 26, and b fills in the
 * lower 26 bits of the 53-bit numerator.
 * The original code credited Isaku Wada for this algorithm, 2002/01/09.
 */

/*[clinic input]
_random.Random.random
  self: self(type="RandomObject *")
random() -> x in the interval [0, 1).
[clinic start generated code]*/

static PyObject *
_random_Random_random_impl(RandomObject *self)
/*[clinic end generated code: output=117ff99ee53d755c input=afb2a59cbbb00349]*/
{
    uint32_t a=genrand_uint32(self)>>5, b=genrand_uint32(self)>>6;
    return PyFloat_FromDouble((a*67108864.0+b)*(1.0/9007199254740992.0));
}

So calling random.random() uses up 2 state values, discards the trailing 5 and 6 bits of their tempered outputs respectively, and combines what is left into the numerator of a fraction to produce a float in [0, 1). More specifically, the top 27 bits of the first output and the top 26 bits of the second are concatenated into a 53-bit integer, which is then divided by $2^{53}$.
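The same construction can be reproduced in pure Python and compared against random.random(). This is a minimal sketch; the function name random_from_words and the seed are my own choices.

```python
import random

def random_from_words(rng):
    """Rebuild random.random() from two 32-bit outputs, mirroring the C code."""
    a = rng.getrandbits(32) >> 5   # keep the top 27 bits
    b = rng.getrandbits(32) >> 6   # keep the top 26 bits
    # (a << 26 | b) / 2**53; the division is exact because 2**53 is a power of two
    return (a * 67108864 + b) / 9007199254740992

seed = 2021                        # arbitrary seed, for the demo only
r1, r2 = random.Random(seed), random.Random(seed)
same = all(r1.random() == random_from_words(r2) for _ in range(1000))
print(same)
```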

The Plan

There are still a lot of issues we will need to deal with, but a hazy plan can be formed at this point. We can pass in an easy value, such as 0, as our guess, and use the round score to solve for the numerator of the fraction in the code above (since the score is conveniently given as a Fraction with values multiplied by $2^{53}$!). Each float recovers 2 state values, so repeating this 312 times should give us 624 state values, enough to clone the random state and predict the next values perfectly. That would leave 1688 guesses, about 700 more than the 977 we need at the end.
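As a local sanity check of the "solve for the numerator" step, the sketch below copies score from the challenge and inverts the printed fraction. The helper numerator_from_score is hypothetical (it is not part of the exploit), and it assumes the guess is exactly 0, so that the score is 2**53 over int(2**53 * x) reduced by a power of two.

```python
import random
from fractions import Fraction

def score(a, b):
    # copied from the challenge source
    if abs(a - b) < 1 / 2**10:
        return Fraction(2**10)
    return Fraction(2**53, int(2**53 * a) - int(2**53 * b))

def numerator_from_score(text):
    """Return int(2**53 * x) from a printed round score, or None if it was capped."""
    s = Fraction(text)             # parses both '1024' and 'p/q'
    if s == 1024:
        return None                # capped: we only learn that x < 1/1024
    return 2**53 * s.denominator // s.numerator

rng, check = random.Random(7), random.Random(7)   # arbitrary seed
recovered_ok = True
for _ in range(1000):
    printed = str(score(rng.random(), 0.0))       # what the server would print
    a_true = check.getrandbits(32) >> 5
    b_true = check.getrandbits(32) >> 6
    n = numerator_from_score(printed)
    if n is not None:
        a, b = n >> 26, n & (2**26 - 1)           # upper 27 bits, lower 26 bits
        recovered_ok &= (a == a_true and b == b_true)
print(recovered_ok)
```

In this local setup the division is exact, so no rounding correction is applied to the result.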

Now there are a few main issues with this plan:

Truncation

Even after solving for the numerator using the round score, and splitting that integer into the upper 27 and lower 26 bits, we're still missing a combined 11 bits of the 2 state values! If we were to brute force those bits for all 312 guesses, that would take $2^{3432}$ attempts, clearly infeasible. The solution here is to solve for the state symbolically: the twist defines each new state value as a function of earlier ones, so once the observed outputs run past a single block of 624 values, these relations constrain the missing bits and a solver can fill them in. To save time, a teammate remembered the symbolic_mersenne_cracker repository (credited at the top of the solution below), which was used to solve a related challenge from PlaidCTF. (Which I spent many hours of my life going down the wrong path on... but I digress.)
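The truncation can be written down as bit patterns in the '0', '1', '?' string format that the solver class below accepts, most significant bit first. This is a sketch with hypothetical helper names (patterns_from_numerator, consistent) and an arbitrary seed.

```python
import random

def patterns_from_numerator(n):
    """Split a 53-bit numerator into two 32-character known/unknown bit patterns."""
    a, b = n >> 26, n & (2**26 - 1)
    first = format(a, '027b') + '?' * 5    # the low 5 bits were shifted away
    second = format(b, '026b') + '?' * 6   # the low 6 bits were shifted away
    return first, second

def consistent(pattern, word):
    """True if every known bit of the pattern agrees with the 32-bit word."""
    return all(p in ('?', c) for p, c in zip(pattern, format(word, '032b')))

rng = random.Random(99)                    # arbitrary seed, for the demo only
w1, w2 = rng.getrandbits(32), rng.getrandbits(32)
n = ((w1 >> 5) << 26) | (w2 >> 6)          # numerator that random.random() would use

first, second = patterns_from_numerator(n)
unknown_per_float = (first + second).count('?')
print(first, second, unknown_per_float, unknown_per_float * 312)
```

Each float leaves 11 unknown bits, which is where the $2^{3432}$ figure for 312 guesses comes from.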

Precision

This one was a killer, and I honestly still have no idea why this happened. Even though int(2**53*a) should be exact, given that Python floats carry 53 bits of precision, in my local testing the values recovered for b (the lower 26 bits of the numerator) were always a little off compared to the actual values in the random state. This too had a relatively simple fix, though I wasted many hours trying to fix the precision without knowing where the error was. I just had to take fewer bits of b for granted, and solve for all the rest. To make up for the reduced information, I used 624 guesses (2 whole states) instead of 312, since we had 2000 guesses in total.
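A rough bit count shows how the two approaches compare in terms of information. This is only a counting argument under my own assumptions (the state is treated as 19937 unknown bits and every observed bit as one independent constraint); it is not a proof that the solver's answer is unique.

```python
STATE_BITS = 19937   # taken from the period 2**19937 - 1

def known_bits(guesses, a_bits=27, b_bits=26):
    """Number of output bits taken as known after the given number of floats."""
    return guesses * (a_bits + b_bits)

plan = known_bits(312)               # original plan: all 53 bits of each float
used = known_bits(624, b_bits=13)    # what was used: 27 bits of a, 13 bits of b

print(plan, plan >= STATE_BITS)      # 16536 False
print(used, used >= STATE_BITS)      # 24960 True
```

By this count, doubling the number of guesses more than makes up for the 13 bits of b dropped from each float.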

After making these fixes, and waiting 30 minutes for the agonizingly slow server since I did not live in India, the flag finally popped out. The full lightly-commented source code of the solution is below.

#This class from https://github.com/icemonster/symbolic_mersenne_cracker
"""
MIT License

Copyright (c) 2021 icemonster

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

from z3 import *
from random import Random
from itertools import count
import time
import logging

logging.basicConfig(format='STT> %(message)s')
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

SYMBOLIC_COUNTER = count()

class Untwister:
    def __init__(self):
        name = next(SYMBOLIC_COUNTER)
        self.MT = [BitVec(f'MT_{i}_{name}', 32) for i in range(624)]
        self.index = 0
        self.solver = Solver()

    #This particular method was adapted from https://www.schutzwerk.com/en/43/posts/attacking_a_random_number_generator/
    def symbolic_untamper(self, solver, y):
        name = next(SYMBOLIC_COUNTER)

        y1 = BitVec(f'y1_{name}', 32)
        y2 = BitVec(f'y2_{name}' , 32)
        y3 = BitVec(f'y3_{name}', 32)
        y4 = BitVec(f'y4_{name}', 32)

        equations = [
            y2 == y1 ^ (LShR(y1, 11)),
            y3 == y2 ^ ((y2 << 7) & 0x9D2C5680),
            y4 == y3 ^ ((y3 << 15) & 0xEFC60000),
            y == y4 ^ (LShR(y4, 18))
        ]

        solver.add(equations)
        return y1

    def symbolic_twist(self, MT, n=624, upper_mask=0x80000000, lower_mask=0x7FFFFFFF, a=0x9908B0DF, m=397):
        '''
            This method models MT19937 function as a Z3 program
        '''
        MT = [i for i in MT] #Just a shallow copy of the state

        for i in range(n):
            x = (MT[i] & upper_mask) + (MT[(i+1) % n] & lower_mask)
            xA = LShR(x, 1)
            xB = If(x & 1 == 0, xA, xA ^ a) #Possible Z3 optimization here by declaring auxiliary symbolic variables
            MT[i] = MT[(i + m) % n] ^ xB

        return MT

    def get_symbolic(self, guess):
        name = next(SYMBOLIC_COUNTER)
        ERROR = 'Must pass a string like "?1100???1001000??0?100?10??10010" where ? represents an unknown bit'

        assert type(guess) == str, ERROR
        assert all(map(lambda x: x in '01?', guess)), ERROR
        assert len(guess) <= 32, "One 32-bit number at a time please"
        guess = guess.zfill(32)

        self.symbolic_guess = BitVec(f'symbolic_guess_{name}', 32)
        guess = guess[::-1]

        for i, bit in enumerate(guess):
            if bit != '?':
                self.solver.add(Extract(i, i, self.symbolic_guess) == bit)

        return self.symbolic_guess


    def submit(self, guess):
        '''
            You need 624 numbers to completely clone the state.
                You can input less than that though and this will give you the best guess for the state
        '''
        if self.index >= 624:
            name = next(SYMBOLIC_COUNTER)
            next_mt = self.symbolic_twist(self.MT)
            self.MT = [BitVec(f'MT_{i}_{name}', 32) for i in range(624)]
            for i in range(624):
                self.solver.add(self.MT[i] == next_mt[i])
            self.index = 0

        symbolic_guess = self.get_symbolic(guess)
        symbolic_guess = self.symbolic_untamper(self.solver, symbolic_guess)
        self.solver.add(self.MT[self.index] == symbolic_guess)
        self.index += 1

    def get_random(self):
        '''
            This will give you a random.Random() instance with the cloned state.
        '''
        logger.debug('Solving...')
        start = time.time()
        print(self.solver.check())
        model = self.solver.model()
        end = time.time()
        logger.debug(f'Solved! (in {round(end-start,3)}s)')

        #Compute best guess for state
        state = list(map(lambda x: model[x].as_long(), self.MT))
        result_state = (3, tuple(state+[self.index]), None)
        r = Random()
        r.setstate(result_state)
        return r

#below code, the actual exploit, is mine
import gmpy2
from pwn import *

ut = Untwister()
sh = remote('crypto.zh3r0.cf', 4444)
for ii in range(624):
    print(ii)
    sh.recvline()
    sh.sendline(b'0')
    lin = sh.recvline()
    flag = False
    sc = lin.decode().strip().split('round score: ')[1]
    if '/' in sc:
        score = sc.split('/')
    else:
        score = sc
        flag = True
    if not flag:
        numer = int(score[0])
        denom = int(score[1])
        state_comb = gmpy2.c_div(gmpy2.mpz(9007199254740992) * gmpy2.mpz(denom), numer) + 1 # solve for state
        bits = bin(state_comb)[2:]
        bits = '0' * (53 - len(bits)) + bits
        a = bits[:27] #upper
        b = bits[27:] #lower
        assert len(a) == 27
        assert len(b) == 26
        ut.submit(a + '?' * 5) #a is truncated 5 bits
        ut.submit(b[:13] + '?' * 19) # b bits are imprecise, take only 13
    else:
        ut.submit('0' * 10 + '?' * 22) #we only know that the first 10 bits are 0 (1/1024)
        ut.submit('?' * 16 + '?' * 16)

rec = ut.get_random()

#score 1024!
while True:
    print(sh.recvline())
    # pwntools expects bytes; str() of a float round-trips exactly through float()
    sh.sendline(str(rec.random()).encode())
    print(sh.recvline())