InCTF 2020 - Bakflipandsons Writeup

By 4cad@TheAdditionalPayphones, 2020-07-31

This was an interesting elliptic-curve challenge. The server generates a 101-bit secret key for NIST192p and lets you sign any message other than the protected plaintext b'please_give_me_the_flag'. If you can give the server a valid signature for that protected plaintext, it gives you the flag.

The twists are as follows:

  1. You are not given the public key, all you get to see is the resulting signatures.
  2. The signature oracle lets you pass in a number that is XORed with the private multiplier before signing. My code and writeup refer to this number as the "jitter" of the signature.
  3. You can use the signature oracle at most 72 times.

Here is the full source of the challenge server:

In [ ]:
import random, sys
from binascii import hexlify, unhexlify
from ecdsa import SigningKey, VerifyingKey, Signature, NIST192p
from Crypto.Util.number import bytes_to_long, long_to_bytes
from flag import flag

secret_multiplier = random.getrandbits(101)

def menu():
    menu = [exit, signMessage, verifyMessage, getFlag, sys.exit]

    print("""
    bakflip&sons Signature Scheme

        1) Sign Message
        2) Verify Signature
        3) Get Flag
        4) Exit
    [ecdsa@cryptolab]# """, end = "")

    choice = int(input())
    menu[choice]()

def signMessage():
    print("""
    Sign Message Service - courtsy of bakflip&sons
    """)

    message = input("Enter a message to sign: ").encode()
    if message == b'please_give_me_the_flag':
        print("\n\t:Coughs: This ain't that easy as Verifier1")
        sys.exit()
    secret_mask = int(input("Now insert a really stupid value here: "))

    secret = secret_multiplier ^ secret_mask

    signingKey = SigningKey.from_secret_exponent(secret)
    signature = signingKey.sign(message)
    print("Signature: ", hexlify(signature).decode())


def verifyMessage():
    raise(
        NotImplementedError(
            "Geez! We are working round the clock to get this Beetle fixed."
        )
    )

def getFlag():
    print("""
    BeetleBountyProgram - by bakflip&sons

        Wanted! Patched or Alive- $200,000
        Submit a valid signature for 'please_give_me_the_flag' and claim the flag
    """)
    signingKey = SigningKey.from_secret_exponent(secret_multiplier)
    verifyingKey = signingKey.verifying_key
    try:
        signature = unhexlify(input("Forged Signature: "))
        if verifyingKey.verify(signature, b'please_give_me_the_flag'):
            print(flag)
    except:
        print("Phew! that was close")

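The public key is easy to recover from a signature and the message that was signed (see `VerifyingKey.from_public_key_recovery` in the python `ecdsa` module for an explanation). Recovery returns a small set of candidate keys rather than a single one, so the attack code further down intersects the candidates from several signatures. The tricky and interesting part is extracting the private key.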

Our ECDSA system has the following components:

  1. The curve, which has a base point $G$ along with the curve parameters including the curve's order $n$. In this case it is the well known NIST192p curve.
  2. The private multiplier $d$, which in this case is a random 101-bit integer.
  3. The public key $Q = dG$.
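
For reference, the recovery step mentioned earlier can be written with these symbols. This is the standard ECDSA signing relation rather than anything specific to the challenge; $z$ (the message hash as an integer), $k$ (the per-signature nonce), $R = kG$ and the signature pair $(r, s)$ are notation introduced here for illustration.

\begin{align} s &= k^{-1}(z + rd) \bmod n, \quad r = x(R) \bmod n\\ sR &= zG + rQ\\ Q &= r^{-1}(sR - zG) \end{align}

Only the $x$-coordinate of $R$ is known from $r$, so both $R$ and $-R$ are candidates, and a single signature typically yields two candidate public keys.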

To execute our attack you don't even need to know the mathematics of ECDSA signatures; all you need to know is that a signed message lets us obtain the public key. Let $a$ be the private key with jitter applied. From a signature made with $a$ we derive a different public key $A = aG$. If the jitter flips only the $k$th bit of $d$, then $a$ differs from $d$ by exactly $2^{k}$: the flip adds $2^{k}$ if the bit was 0 and subtracts it if the bit was 1. So we end up with

\begin{align} a &= d + b2^{k}, b \in \{-1, 1\}\\ A &= (d + b2^{k})G\\ A &= dG + b2^{k}G\\ A &= Q + b2^{k}G \end{align}

So we can derive this bit of the secret key by trying both values of $b$: bit $k$ of $d$ is 0 when $b = 1$ and 1 when $b = -1$. \begin{equation} b = \begin{cases} 1, & A = Q + 2^{k}G \\ -1, & A = Q - 2^{k}G \end{cases} \end{equation}
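
The single-bit step can be tried out without a server. The sketch below is a minimal illustration on a toy curve, $y^2 = x^3 + 2x + 2$ over $GF(17)$ with $G = (5, 1)$ and group order 19; that curve, the helper names (`add`, `neg`, `mul`, `recover_bit`, `recover_key`) and the 4-bit key size are all assumptions made for the example, not part of the challenge, which uses NIST192p. The oracle is simulated locally by computing $(d \oplus 2^{k})G$ directly.

```python
# Toy curve y^2 = x^3 + 2x + 2 over GF(17); G = (5, 1) generates a group of order 19.
# These parameters are an illustrative assumption, not the challenge's NIST192p curve.
P_MOD, A_COEF = 17, 2
G = (5, 1)
INF = None  # point at infinity


def add(p1, p2):
    """Add two affine points on the toy curve."""
    if p1 is INF:
        return p2
    if p2 is INF:
        return p1
    (x1, y1), (x2, y2) = p1, p2
    if x1 == x2 and (y1 + y2) % P_MOD == 0:
        return INF  # p2 is the negative of p1
    if p1 == p2:
        slope = (3 * x1 * x1 + A_COEF) * pow((2 * y1) % P_MOD, -1, P_MOD)
    else:
        slope = (y2 - y1) * pow((x2 - x1) % P_MOD, -1, P_MOD)
    slope %= P_MOD
    x3 = (slope * slope - x1 - x2) % P_MOD
    return (x3, (slope * (x1 - x3) - y1) % P_MOD)


def neg(p):
    """Negate a point."""
    return INF if p is INF else (p[0], -p[1] % P_MOD)


def mul(k, p):
    """Double-and-add scalar multiplication for k >= 0."""
    acc = INF
    while k:
        if k & 1:
            acc = add(acc, p)
        p = add(p, p)
        k >>= 1
    return acc


def recover_bit(Q, A, k):
    """Return bit k of d, given Q = d*G and A = (d XOR 2**k)*G."""
    step = mul(1 << k, G)
    if A == add(Q, step):       # b = +1: the flip set the bit, so it was 0
        return 0
    if A == add(Q, neg(step)):  # b = -1: the flip cleared the bit, so it was 1
        return 1
    raise ValueError("A is not Q +/- 2^k G")


def recover_key(d, bits=4):
    """Simulate the oracle locally and rebuild d one bit at a time."""
    Q = mul(d, G)
    return sum(recover_bit(Q, mul(d ^ (1 << k), G), k) << k for k in range(bits))


print(recover_key(11))
```

The attacker in the real setting never sees `d`; only `Q` and `A` (both recovered from signatures) enter `recover_bit`, which is why one oracle query per bit would suffice if the budget allowed it.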

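This is the core of the attack. However, since we can only call the oracle 72 times, one bit per query is not enough for a 101-bit key. We adjust the attack to brute force 8 bits at a time instead of 1 bit: the jitter flips a whole byte, and we try all 256 candidate byte values. The code asks for three signatures per query to pin down the public key, so the 14 byte queries plus the initial query for $Q$ cost 45 oracle calls, which easily gives us all 101 bits of the private key. See the full source of the attack for details: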

In [ ]:
import random, sys
from binascii import hexlify, unhexlify
from ecdsa import SigningKey, VerifyingKey, NIST192p
from Crypto.Util.number import bytes_to_long, long_to_bytes
import itertools
from pwn import *

conn = remote('34.74.30.191', 9999)

# Proof-of-work code is omitted for conciseness

def pp(x) :
    print(str(x, 'ASCII'))

def sign_with_jitter(message, jitter) :
    # Ask the remote oracle to sign `message` with (secret_multiplier ^ jitter)
    conn.send(b'1\n')
    pp(conn.recvuntil(b':'))
    conn.send(message + b'\n')
    pp(conn.recvuntil(b':'))
    conn.send(bytes('%d\n'%jitter, 'ASCII'))
    pp(conn.recvuntil(b'Signature:'))
    signatureRaw = conn.recvuntil(b'\n\n')
    print('signatureRaw: ', signatureRaw)
    signature = unhexlify(str(signatureRaw, 'ASCII').strip())
    print('signature: ', signature)
    pp(conn.recvuntil(b'#'))
    return signature

G = NIST192p.generator
n = NIST192p.order
# Local stand-in secret, used only by sign_with_jitter_test for offline testing
secret_multiplier = random.getrandbits(101)

def determine_public_key(signatures) :
    result = None
    for signature in signatures : 
        keys = VerifyingKey.from_public_key_recovery(signature, b'test', NIST192p)
        if result is None :
            result = list(keys)
        else :
            new_result = list()
            for k in keys :
                if k in result :
                    new_result.append(k)
            result = new_result

    assert len(result) == 1
    return result[0]

def sign_with_jitter_test(message, jitter) :
    # Offline stand-in for sign_with_jitter: signs with the local secret_multiplier
    signingKey = SigningKey.from_secret_exponent(secret_multiplier ^ jitter)
    return signingKey.sign(message)

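def encode_byte(b, byteIndex) :
    # If byte `byteIndex` of the secret equals b, XORing that byte with 0xFF changes
    # the secret by the offset computed here (returned mod n so it can multiply G).
    result = 0
    for index in range(8) :
        powerOfTwo = 1 << index
        bit = b & powerOfTwo
        if bit == 0 :
            # Bit was 0 and is flipped to 1: jitterPoint = pubkey + powerOfTwo*G
            result += powerOfTwo << byteIndex*8
        else :
            # Bit was 1 and is flipped to 0: jitterPoint = pubkey - powerOfTwo*G
            result -= powerOfTwo << byteIndex*8
    if result < 0 :
        return n + result
    else :
        return result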

def extract_private_key_byte(byteIndex, Q) :
    jitterPoint =  determine_public_key([sign_with_jitter(b'test', 255 << 8*byteIndex) for _ in range(3)]).pubkey.point
    for b in range(256) :
        if jitterPoint.x() == (Q+encode_byte(b, byteIndex)*G).x() :
            return b
    return None
        

def forge_signature(message) :
    verifyingKey = determine_public_key([sign_with_jitter(b'test', 0) for _ in range(3)])
     
    result = 0
    for byteIndex in range(2 + (101//8)) :
        result += extract_private_key_byte(byteIndex, verifyingKey.pubkey.point) << 8*byteIndex
        print('>>>>>>>>>>>>> [%d/%d] attack result     ='%(byteIndex,2 + (101//8)), hex(result))
        
    # secret_multiplier is the local stand-in; it only matches when testing offline
    print('secret_multiplier =', hex(secret_multiplier))
    print('attack result     =', hex(result))
    forgingKey = SigningKey.from_secret_exponent(result)
    forgedSignature = forgingKey.sign(message)
    valid = verifyingKey.verify(forgedSignature, message)
    assert valid
    return forgedSignature

forgedSignature = forge_signature(b'please_give_me_the_flag')
print(forgedSignature)

conn.send(b'3\n')
pp(conn.recvuntil(':'))

finalMessage = str(hexlify(forgedSignature), 'ASCII') + '\n'
print('Sending: %s'%finalMessage)
conn.send(bytes(finalMessage, 'ASCII'))
pp(conn.recvuntil('#'))
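
The byte-wise step in `encode_byte` rests on a plain integer identity: XORing one byte of $d$ with `0xFF` turns the byte value $v$ into $255 - v$, so $d$ changes by $(255 - 2v) \cdot 2^{8i}$ for byte index $i$. The sketch below checks that identity on integers only; `byte_offset` and `recover_byte` are hypothetical helper names, and the comparison `d + offset == flipped` stands in for the point comparison $Q + \text{offset} \cdot G = A$ that the attack performs, since the attacker never holds `d` itself.

```python
import random


def byte_offset(value, byte_index):
    """Signed change to d when byte `byte_index`, currently `value`, is XORed with 0xFF."""
    return (255 - 2 * value) << (8 * byte_index)


def recover_byte(d, flipped, byte_index):
    """Find the byte value whose offset explains the difference between flipped and d."""
    for value in range(256):
        if d + byte_offset(value, byte_index) == flipped:
            return value
    return None


d = random.getrandbits(101)          # stand-in secret, sized as in the challenge
for i in range(13):                  # 13 bytes cover 101 bits
    flipped = d ^ (0xFF << (8 * i))  # the exponent the oracle would sign with
    assert recover_byte(d, flipped, i) == (d >> (8 * i)) & 0xFF
```

Each of the 256 candidate values gives a distinct odd multiple of $2^{8i}$, so at most one candidate can match, which is why a single jittered public key is enough to fix a whole byte.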