/* * $Id: rsa_math.c,v 1.1 2005/03/21 05:23:44 jerub Exp $ * * This is the server-side custom math code that is used by * rsa_util.c. No crypto code is contained in this file. * * Originally written by Ray Jones. Updated by Dave Ahn. * */ #include #include #include "config.h" #include "rsa_math.h" #undef MATH_DEBUG /* enable debugging */ #undef MATH_FAST /* use macros instead of compiler inlining */ /**************************************************************************** * to work correctly, SIZE must be 2 * the largest number you plan to deal * with, and also a power of 2. Unfortunate, but true. For the netrek-RSA * scheme, it has to be at least twice 80. */ /* this should already be defined by rsa_math.h */ #ifndef SIZE #define SIZE 64 #endif inline static void math_exit(const int i) { printf("math_exit: value = %d\n", i); exit(1); } /* debugging */ #ifdef MATH_DEBUG inline static void print_num(const int *n, const int digits) { int i; for (i=0; i>= 8; } if (temp) math_exit(temp); } /**************************************************************************** * Add and Subtract. Simple implementations, obviously, given that we don't * need to keep the number normalized. */ #ifdef RSA_FAST /* these macros really aren't necessary with modern day compilers that are smart about inlining */ #define add(out, a, b, digits) { \ int tempi, *tempout, *tempa, *tempb;\ for(tempi=0,tempa=(a),tempb=(b),tempout=(out);tempi<(digits);tempi++)\ *(tempout++)= *(tempa++)+ *(tempb++);\ } #define subtract(out, a, b, digits) {\ int tempi, *tempout, *tempa, *tempb;\ for(tempi=0,tempa=(a),tempb=(b),tempout=(out);tempi<(digits);tempi++)\ *(tempout++)= *(tempa++)- *(tempb++);\ } #else /* !RSA_FAST */ inline static void add(int *out, const int *a, const int *b, const int digits) { int i; for (i=0; i> 1) { multiply(out, a, b, new_digits); multiply(&(out[digits]), &(a[new_digits]), &(b[new_digits]), new_digits); add(mid1, a, &(a[new_digits]), new_digits); add(mid2, b, &(b[new_digits]), new_digits); multiply(temp, mid1, mid2, new_digits); subtract(temp, temp, out, digits); subtract(temp, temp, &(out[digits]), digits); add(&(out[new_digits]), &(out[new_digits]), temp, digits); return; } i = (*a) * (*b); *out = i & 0xFF; *(out + 1) = i >> 8; } /* a is digits long, out is 2 * digits */ inline static void square(int *out, int *a, const int digits) { int temp[SIZE]; int i, new_digits; if (new_digits = digits >> 1) { square(out, a, new_digits); square(&(out[digits]), &(a[new_digits]), new_digits); multiply(temp, a, &(a[new_digits]), new_digits); /* multiply by 2 */ for (i=0; i> 8; } /**************************************************************************** * bitwise shifts, copy, and compare. compare() only accepts normalized * numbers. */ inline static void shift_left1(int *n) { int i; for (i=0; i>= 1; if (n[i + 1] & 0x1) { n[i] += 0x80; } } n[i] >>= 1; } inline static void copy(int *out, const int *in) { int i; for (i=0; i b, -1 if b > a, or 0 if a == b */ inline static int compare(const int *a, const int *b) { int i; for (i=0; i b[SIZE - i - 1]) return 1; if (a[SIZE - i - 1] < b[SIZE - i - 1]) return -1; } return 0; } /**************************************************************************** * This method for doing mods is from the same book as the multiply function. * It's based on the fact that * (ax + b) % m = (a(x % m) + b) % m * Before calling modulus or expmod, you have to call setup_modulus. You don't * have to pass the modulus into modulus or expmod, since it's in the table. * cleanup_modulus will clean up the tables afterwards. */ static int modulus_size; static int modulus_in_table[SIZE]; static int modulus_table[SIZE][SIZE]; void setup_modulus(const int *modulus) { int i; int temp[SIZE]; copy(modulus_in_table, modulus); for (i=0; i> 3) < SIZE) */ while (i<(SIZE << 3)) { if (!(i & 0x7)) { copy(modulus_table[i >> 3], temp); } shift_left1(temp); renormalize(temp); if (compare(temp, modulus) > 0) { subtract(temp, temp, modulus, SIZE); renormalize(temp); } i++; } for (i=0; i 0) { for (i=0; i<(modulus_size - 1); i++) { from_table[i] = temp[i]; } for (; i