
// Brought to you by the sk3wl of r3v3rsing

#include <stdio.h>
#include <string.h>
#include <stdlib.h>

//One pairwise byte constraint used by the verification stages: the stages
//compute s[idx1] - s[idx2] and compare the result against diff.
struct diffStruct {
   int idx1;            //index of the byte being subtracted from
   int idx2;            //index of the byte being subtracted
   unsigned char diff;  //expected difference
};

//Primality lookup table for every byte value: primes[n] is 1 when n is
//prime and 0 otherwise, for n in 0..255 (256 entries, 32 per row).
unsigned char primes[] = {
   0,0,1,1,0,1,0,1,0,0,0,1,0,1,0,0,0,1,0,1,0,0,0,1,0,0,0,0,0,1,0,1,
   0,0,0,0,0,1,0,0,0,1,0,1,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,1,0,0,
   0,0,0,1,0,0,0,1,0,1,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,
   0,1,0,0,0,1,0,1,0,0,0,1,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,
   0,0,0,1,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,1,0,0,
   0,0,0,1,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,1,
   0,1,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,
   0,0,0,1,0,1,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0
};

//Primality test for a single byte via the primes[] lookup table.
//(Kenshoto's original computed this differently; the table is our rewrite,
//chosen for speed.)
//Returns true exactly when ch is a prime number.
bool isPrime(unsigned char ch) {
   const unsigned char flag = primes[ch];
   return flag != 0;
}

//Per-position additive constants for stage 0 (func_0 adds f0[i] to s[i]).
unsigned char f0[24] = {
   0xB2, 0xF3, 0x7E, 0x1A, 0x5B, 0x64,
   0x97, 0x94, 0xDB, 0xA8, 0xF8, 0xD9,
   0x83, 0xFE, 0x8E, 0x27, 0x43, 0x19,
   0x1E, 0x89, 0xA2, 0xCE, 0x9C, 0xED
};

//Stage 0 transform, in place: walking left to right, each byte gets its
//right-hand neighbour plus the matching f0 constant added to it (mod 256).
//The last byte wraps around to s[0], which by then has already been updated.
//Always returns true.
bool func_0(int len, unsigned char *s) {
   const int last = len - 1;
   for (int pos = 0; pos < last; ++pos) {
      s[pos] = (unsigned char)(s[pos] + s[pos + 1] + f0[pos]);
   }
   if (len > 0) {
      //wrap-around step: uses the already-modified s[0]
      s[last] = (unsigned char)(s[last] + s[0] + f0[last]);
   }
   return true;
}

//byte differences that must hold after stage_0.
//Every index here is a multiple of 3, i.e. positions with (i % 3) == 0.
diffStruct f1[8] = {
   {18, 15, 0xE6}, {9, 3, 0xCE}, {21, 9, 0xE0}, {12, 0, 0x18}, 
   {15, 18, 0x1A}, {3, 12, 0x83}, {0, 6, 0x11}, {6, 21, 0xA6}
};

//Stage 1 check: verify that every difference constraint in f1 holds on the
//buffer left by stage 0. All f1 indices satisfy (i % 3) == 0, so only those
//positions are inspected.
//Returns false on the first failed constraint, or if a constraint refers to
//an index at or beyond len; true otherwise.
bool func_1(int len, unsigned char *s) {
   const int count = (int)(sizeof f1 / sizeof f1[0]);
   for (int i = 0; i < count; i++) {
      int idx1 = f1[i].idx1;
      int idx2 = f1[i].idx2;
      //don't read past a buffer shorter than the table expects
      if (idx1 >= len || idx2 >= len) {
         return false;
      }
      //Compare as a byte (mod 256). A plain, usually signed, char would be
      //promoted to a negative int and never equal table values >= 0x80.
      unsigned char diff = (unsigned char)(s[idx1] - s[idx2]);
      if (diff != f1[i].diff) {
         return false;
      }
   }
   return true;
}

//Per-position additive constants for stage 2 (func_2 adds f2[i] to s[i]
//after the XOR step).
unsigned char f2[24] = {
   0xB2, 0xF3, 0x7F, 0x1E, 0x5B, 0x63,
   0x97, 0x93, 0xDB, 0xA8, 0xF8, 0xE0,
   0x83, 0x00, 0x8E, 0x27, 0x43, 0x18,
   0x1E, 0x8A, 0xA2, 0xCE, 0x9C, 0xE8
};

//Stage 2 transform, in place: walking left to right, each byte is XORed with
//its right-hand neighbour and then has f2[i] added (mod 256).
//The last byte wraps around to s[0], which has already been updated.
//Always returns true.
bool func_2(int len, unsigned char *s) {
   for (int pos = 0; pos < len; ++pos) {
      const int next = (pos + 1 == len) ? 0 : pos + 1;
      s[pos] = (unsigned char)((s[pos] ^ s[next]) + f2[pos]);
   }
   return true;
}

//byte differences that must hold after stage_2.
//Every index here satisfies (i % 3) == 1. The {1, 1, 0} entry compares a
//byte with itself, so it is always satisfied.
diffStruct f3[8] = {
   {13, 16, 0xA}, {4, 22, 0xD7}, {7, 13, 0x72},
   {1, 1, 0},     {22, 4, 0x29}, {19, 10, 0xDB},
   {10, 7, 0xF5}, {16, 19, 0xB4}
};

//Stage 3 check on the buffer left by stage 2:
// 1. every prime index below len must hold a prime byte value;
// 2. every difference constraint in f3 must hold. All f3 indices satisfy
//    (i % 3) == 1, so this part inspects only those positions.
//Returns false on the first failed check, or if a constraint refers to an
//index at or beyond len; true otherwise.
bool func_3(int len, unsigned char *s) {
   //if the index is prime, then so must be the value
   for (int i = 0; i < len; i++) {
      if (isPrime((unsigned char)i) && !isPrime(s[i])) {
         return false;
      }
   }
   const int count = (int)(sizeof f3 / sizeof f3[0]);
   for (int i = 0; i < count; i++) {
      int idx1 = f3[i].idx1;
      int idx2 = f3[i].idx2;
      //don't read past a buffer shorter than the table expects
      if (idx1 >= len || idx2 >= len) {
         return false;
      }
      //Compare as a byte (mod 256). A plain, usually signed, char would be
      //promoted to a negative int and never equal table values >= 0x80.
      unsigned char diff = (unsigned char)(s[idx1] - s[idx2]);
      if (diff != f3[i].diff) {
         return false;
      }
   }
   return true;
}

//Stage 4, implemented in func4.asm (not part of this file).
//Declared with C linkage so this unmangled name matches the symbol the
//assembly defines.
//NOTE(review): presumably a transform like func_0/func_2, sitting at an
//even stage index; confirm against func4.asm.
extern "C" {
bool func_4(int len, unsigned char *s);
}

//byte differences that must hold after stage_4.
//Every index here satisfies (i % 3) == 2, complementing f1 ((i % 3) == 0)
//and f3 ((i % 3) == 1).
diffStruct f5[8] = {
   {8, 14, 0x49}, {14, 11, 0xB8}, {2, 17, 0x57},
   {20, 5, 0x38}, {17, 8, 0x99}, {11, 20, 0x41},
   {23, 2, 0xC5}, {5, 23, 0xD1}
};

//verify that all difference relationship hold as define above
//This function operates only on bytes in positions (i % 3) == 0
bool func_5(int len, unsigned char *s) {
   for (int i = 0; i < 8; i++) {
      int idx1 = f1[i].idx1;
      int idx2 = f1[i].idx2;
      char diff = s[idx1] - s[idx2];
      if (diff != f1[i].diff) {
         return false;
      }
   }
   return true;
}

//Weak checksum used as the final gate in build_key.
//Starting from all bits set, each of the len rounds:
// - XORs the running value shifted right by 8 with its own low byte;
// - XORs in one 4-byte word of s, cycling through byte offsets
//   0, 4, 8, 12, 16, 20.
//For len >= 6 this reads the first 24 bytes of s whatever len is, so s must
//span at least 24 bytes (assuming a 4-byte int).
//NOTE: hash >> 8 on a negative int relies on arithmetic right shift, which
//mainstream compilers do and C++20 guarantees. Changing the shift semantics
//would change the resulting hash values.
unsigned int bad_hash(int len, unsigned char *s) {
   //All bits set: the value 0xFFFFFFFF became after the old implicit
   //conversion to int on two's-complement targets.
   int hash = -1;
   for (int i = 0; i < len; i++) {
      int temp = (hash >> 8) ^ (hash & 0xFF);
      //memcpy instead of *(int*)ptr avoids a strict-aliasing violation and a
      //possibly misaligned load (s may point into argv data).
      int word;
      memcpy(&word, s + (i % 6) * 4, sizeof word);
      hash = temp ^ word;
   }
   return (unsigned int)hash;
}

//Key material: build_key and subtractKey both compute key_bytes[i] minus
//the i-th input byte (mod 256) to produce the decrypted key.
unsigned char key_bytes[24] = {
   0x72, 0x35, 0x75, 0xB0, 0xE0, 0x05,
   0xB7, 0x84, 0x76, 0x3E, 0x42, 0x89,
   0x64, 0x86, 0x31, 0x8F, 0x11, 0xCA,
   0x53, 0x64, 0x41, 0x75, 0xEB, 0xE7
};

//In-place key derivation: replaces each of the 24 bytes of k with
//key_bytes[i] - k[i] (mod 256). k must hold at least 24 bytes.
void subtractKey(unsigned char k[]) {
   unsigned char *dst = k;
   const unsigned char *src = key_bytes;
   for (size_t remaining = sizeof key_bytes; remaining > 0; --remaining) {
      *dst = (unsigned char)(*src - *dst);
      ++dst;
      ++src;
   }
}

//Signature shared by every stage: (buffer length, buffer) -> bool.
//main treats a false/0 return as failure.
typedef bool (*stage)(int, unsigned char*);

//All six stages in execution order. main calls them in pairs (even index,
//then odd index), each operating in place on the same argument buffer.
stage stage_funcs[] = {
   func_0, func_1, func_2, func_3, func_4, func_5
};

//Final gate and key derivation.
//If the transformed input hashes to the expected value, writes
//out[i] = key_bytes[i] - in[i] (mod 256) for each of the len positions.
//  in  : the transformed buffer. bad_hash reads it in 4-byte words up to
//        offset 20, so it should span at least 24 bytes.
//  out : must hold at least len bytes. It is NOT NUL-terminated here.
//Returns false if len is negative, len exceeds the 24-byte key_bytes table,
//or the hash does not match; true otherwise.
bool build_key(int len, char in[], char out[]) {
   //key_bytes has only 24 entries; larger lengths would read past it
   if (len < 0 || len > (int)sizeof key_bytes) return false;
   if (bad_hash(len, (unsigned char*)in) != 0xC668359Bu) return false;
   for (int i = 0; i < len; i++) {
      out[i] = (char)(key_bytes[i] - in[i]);
   }
   return true;
}

//some liberties taken for the sake of brevity
//Driver: runs the stages pairwise over argv[1] in place, then derives and
//prints the key.
//Exit codes:
//  1 - missing argument, or argument not exactly 24 characters
//  2 - an even-indexed stage returned false
//  3 - an odd-indexed stage returned false
//  4 - the final hash check failed
//  0 - success
int main(int argc, char **argv) {
   //24 key bytes plus a terminator: build_key does not NUL-terminate, and
   //buf is printed with %s below.
   char buf[25] = {0};
   //strlen dereferences argv[1], so make sure it exists first
   if (argc < 2 || strlen(argv[1]) != 24) return 1;
   for (int i = 0; i < 6; i += 2) {
      if ((stage_funcs[i])(24, (unsigned char*)argv[1]) == 0) return 2;
      if ((stage_funcs[i + 1])(24, (unsigned char*)argv[1]) == 0) return 3;
   }
   if (build_key(24, argv[1], buf) == 0) return 4;
   printf("Congratulations. Your decrypted key is: %s\n", buf);
   return 0;
}
