AES Decryption not working correctly but all its functions seem to work correctly

This is the encryption function :

def AES_Encryption (text) :
    round_keys = generate_round_keys()
    #print("Round keys = " , round_keys)
    aes_texts = break_text(text)
    #print(aes_texts)
    print(aes_texts)
    aes_texts = text_to_hex(aes_texts)
    print(aes_texts)

    for j in range(0,15) :
        #print("j=",j)
        for i in range(len(aes_texts)) :
            if j == 0 :
                aes_texts[i] = Add_Round_key(aes_texts[i] , round_keys[j])
            else :
                aes_texts[i] = sbox_conversion(aes_texts[i] , False)
                if j != 14 :
                    aes_texts[i] = MixColumns (aes_texts[i])
                aes_texts[i] = Add_Round_key(aes_texts[i] , round_keys[j])
        print(aes_texts)
    #print(aes_texts)
    return aes_texts , round_keys
 

This is my string = “What are you gonna do? Gonna cry?”
This will create 3 2d lists in the 3d list called aes_texts.
Remember aes_texts is a 3d list. Where 16 hexas are stored in a 4X4 2d list.

This is the decryption function :

def AES_decryption(texts, round_keys):
    aes_texts = texts
    for j in range(14, -1, -1):
        for i in range(len(aes_texts)):
            aes_texts[i] = Add_Round_key(aes_texts[i], round_keys[j])
            if j != 0:
                if j != 1 :
                    aes_texts[i] = InvMixColumns(aes_texts[i])
                aes_texts[i] = inverse_shift_rows(aes_texts[i])
                aes_texts[i] = sbox_conversion(aes_texts[i], True)
        print(aes_texts)
    return aes_texts

Now here’s what I tried I did this aes_texts[i] = Add_Round_key(aes_texts[i] , round_keys[j]) and aes_texts[i] = Add_Round_key(aes_texts[i] , round_keys[j]) and I am getting the hex text before as expected.
Similarly I did aes_texts[i] = sbox_conversion(aes_texts[i] , False) and aes_texts[i] = sbox_conversion(aes_texts[i] , True) and I got the hexa before the operations as expected.

    Similarly I did aes_texts\[i\] = shift_rows(aes_texts\[i\]) and aes_texts\[i\] = inverse_shift_rows(aes_texts\[i\]) and i got the hex text before as expected.                                               
     I also tried aes_texts\[i\] = MixColumns (aes_texts\[i\]) and aes_texts\[i\] = InverseMixColumns (aes_texts\[i\]) and I got the same hex as before as expected.                                          

This means that all the functions are working correctly. But after 14 round encryption when I do this ,

print("Texts before 1 decryption" , aes_texts)
for i in range(len(aes_texts)):
    aes_texts[i] = Add_Round_key(aes_texts[i], round_keys[14])
    aes_texts[i] = InvMixColumns(aes_texts[i])
    aes_texts[i] = inverse_shift_rows(aes_texts[i])
    aes_texts[i] = sbox_conversion(aes_texts[i], True)
    #aes_texts[i] = Add_Round_key(aes_texts[i], round_keys[13])
    #print(aes_texts)
print("Texts after 1 decryption" , aes_texts)
 

I don’t get the previous hexes. What seems to be the issue here. ## Code for Add_round_keys ##

def Add_Round_key(hex_converted_text , round_key) :
    addition=[]
    #print("Before rounf key : " , hex_converted_text)
    for i in range (0 , 4) :
        row=[]
        for j in range (0 , 4) :
            row.append ('{:02x}'.format((int(hex_converted_text[i][j], 16) ^ int(round_key[j][i], 16))))
        addition.append(row)
    #print("After rounf key : " , addition)
    return addition

Code for sbox and invsbox

def sbox_conversion(text , sub_type = False) :
    #print("Sbox text before : " , text)
    for i in range(len(text)) :
        string=[]
        for k in text[i] :
            string.append(k)
   #     print("string = " , string)
        
        for j in range(len(string)) :
            if len(string[j]) < 2 :
                a = string[0]
                string[j][0] = '0'
                string[j].append(a)
            for k in range(0,16) :
                change_flag = False
                for l in range (0,16) :
                    x = '{:01x}'.format(k)
                    y = '{:01x}'.format(l)
                    a , b = string[j][0] , string[j][1]
                    #print("a = " , a , " x = " , x , " b = " , b , "y = " , y)
                    if (a == x and b == y) :
                        if sub_type == False :
                            text[i][j] = '{:02x}'.format(sbox[k][l])
                        else :
                            text[i][j] = '{:02x}'.format(sboxInv[k][l])
                        change_flag = True
                        break
                if change_flag == True :
                    break
    #print("Sbox text after : " , text)
    return text

Code for mix Columns and Inverse Mix columns

def multiply(a, b):
    p = 0
    hiBitSet = 0
    for i in range(8):
        if b & 1 == 1:
            p ^= a
        hiBitSet = a & 0x80
        a <<= 1
        if hiBitSet == 0x80:
            a ^= 0x1b
        b >>= 1
    return p % 256

# MixColumns operation
def MixColumns(state):
    for i in range(4):
        s0 = state[0][i]
        s1 = state[1][i]
        s2 = state[2][i]
        s3 = state[3][i]
        state[0][i] = '{:02x}'.format((multiply(int(s0,16) , 0x02) ^ multiply(int(s1,16) , 0x03) ^ int(s2,16) ^ int(s3,16)) % 256)
        state[1][i] = '{:02x}'.format((int(s0,16) ^ multiply(int(s1,16),0x02) ^ multiply(int(s2,16),0x03) ^ int(s3,16)) % 256)
        state[2][i] = '{:02x}'.format((int(s0,16) ^ int(s1,16) ^ multiply(int(s2,16),0x02) ^ multiply(int(s3,16),0x03)) % 256)
        state[3][i] = '{:02x}'.format((multiply(int(s0,16),0x03) ^ int(s1,16) ^ int(s2,16) ^ multiply(int(s3,16),0x02)) % 256)
    return state




def InvMixColumns(state):
    for i in range(4):
        s0 = state[0][i]
        s1 = state[1][i]
        s2 = state[2][i]
        s3 = state[3][i]
        state[0][i] = '{:02x}'.format((multiply(int(s0,16) , 0x0e) ^ multiply(int(s1,16) , 0x0b) ^ multiply(int(s2,16) , 0x0d) ^ multiply(int(s3,16) , 0x09)) % 256)
        state[1][i] = '{:02x}'.format((multiply(int(s0,16) , 0x09) ^ multiply(int(s1,16) , 0x0e) ^ multiply(int(s2,16) , 0x0b) ^ multiply(int(s3,16) , 0x0d)) % 256)
        state[2][i] = '{:02x}'.format((multiply(int(s0,16) , 0x0d) ^ multiply(int(s1,16) , 0x09) ^ multiply(int(s2,16) , 0x0e) ^ multiply(int(s3,16) , 0x0b)) % 256)
        state[3][i] = '{:02x}'.format((multiply(int(s0,16) , 0x0b) ^ multiply(int(s1,16) , 0x0d) ^ multiply(int(s2,16) , 0x09) ^ multiply(int(s3,16) , 0x0e)) % 256)
    return state

Code for Shift Rows and Inverse Shift rows

def shift_rows (text) :
    shift = 1
    #print("text before shift = " , text)
    for i in range (len (text)) :
        word = np.array(text[i])
        word = np.roll(word,-i)
        text[i]=word.tolist()
    ##print("text after shift = " , text)
    return text

def inverse_shift_rows (text) :
    shift = 1
    #print("text before shift = " , text)
    for i in range (len (text)) :
        word = np.array(text[i])
        word = np.roll(word,i)
        text[i]=word.tolist()
    ##print("text after shift = " , text)
    return text

  • Find (an implementation that creates) intermediate results for AES and try and find the error that way. This is not a debugging service.

    – 

  • 1

    nvlpubs.nist.gov/nistpubs/fips/nist.fips.197.pdf contains a complete calculation of AES, including intermediate values. You can use it to compare against your own intermediate values. Why are you writing your own AES. This is a very bad idea, unless you’re just doing it for homework.

    – 

Leave a Comment