/*
 * rsa_demo.c : exemple d'implémentation de RSA, pour illustrer un TPE
 * Nicolas George -- février 2008
 * Domaine public
 *
 * Usage:
 *
 * ./rsa_demo genrsa 4 10
 * p: 0x830EDDC2D122D523
 * q: 0xC62291B2DF3319B1
 * m: 0x656F32151FD65DAD7883610EF2DEC833
 * e: 0x10001
 * d: 0x409722091990E36D3564B4B7A70A8581
 *
 * ./rsa_demo rsa 0x656F32151FD65DAD7883610EF2DEC833 0x10001 0x123456789
 * 0x5A35C02DF45FE05562B2113536BA991E
 *
 * ./rsa_demo rsa 0x656F32151FD65DAD7883610EF2DEC833 0x409722091990E36D3564B4B7A70A8581 0x5A35C02DF45FE05562B2113536BA991E
 * 0x123456789
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

/*
   Measured run times of "findprime size fiability" (machine not recorded):
   findprime 256 10 -> 2613.88s
   findprime 64 10 -> 24.10s
   findprime 64 100 -> 173.14s
   findprime 32 100 -> 33.62s
   findprime 16 100 -> 4.08s
 */

#define BBASE 16                /* bits per digit */
#define BASE (1 << BBASE)       /* radix of the digit representation */

/* Evaluates its arguments twice: only pass side-effect-free expressions. */
#define MAX(a, b) ((a) > (b) ? (a) : (b))

/* Digits are drawn with rand() % BASE; require RAND_MAX >= BASE. */
#if RAND_MAX < BASE
# error I need a bigger random number generator.
#endif

/*
 * Arbitrary-precision natural number, stored little-endian in base BASE:
 * digits[0] is the least significant digit.  A reduced number (see
 * num_reduce) has no zero most-significant digit; zero then has size 0.
 */
typedef struct {
    unsigned size;              /* number of digits stored */
    unsigned short digits[];    /* C99 flexible array member */
} num_t;

/*
 * Allocates a number with room for `size` digits, all set to zero (so the
 * result is not reduced when size > 0).  Exits on allocation failure.
 */
num_t *
num_allocate(unsigned size)
{
    num_t *res = malloc(sizeof *res + size * sizeof res->digits[0]);
    unsigned k;

    if(res == NULL) {
        perror("Out of memory");
        exit(1);
    }
    res->size = size;
    for(k = size; k-- > 0; )
        res->digits[k] = 0;
    return(res);
}

/*
 * Releases a number obtained from num_allocate or any num_* constructor.
 * Like free(), accepts NULL.
 */
void
num_free(num_t *n)
{
    void *block = n;

    free(block);
}

num_t *
num_of_int(unsigned v)
{
    num_t *n;

    n = num_allocate(v == 0 ? 0 : 1);
    if(v != 0)
	n->digits[0] = v;
    return(n);
}

/*
 * Strips the zero most-significant digits so that n->size is minimal and
 * shrinks the allocation to match.  Returns the (possibly moved) number:
 * the pointer passed in must not be used afterwards.  Exits on allocation
 * failure.
 */
num_t *
num_reduce(num_t *n)
{
    unsigned used = n->size;

    while(used != 0 && n->digits[used - 1] == 0)
        used--;
    if(used == n->size)
        return(n);
    n->size = used;
    n = realloc(n, sizeof(num_t) + used * sizeof(unsigned short));
    if(n == NULL) {
        perror("Out of memory");
        exit(1);
    }
    return(n);
}

/*
 * Returns the number of significant bits of n (0 when n->size is 0).
 * Only the bits of the top digit are counted individually, so n is
 * expected to be reduced.
 */
unsigned
num_bits(num_t *n)
{
    unsigned top, bits;

    if(n->size == 0)
        return(0);
    bits = (n->size - 1) * BBASE;
    top = n->digits[n->size - 1];
    while(top != 0) {
        top >>= 1;
        bits++;
    }
    return(bits);
}

/*
 * Returns a freshly allocated copy of n with the same size and digits.
 */
num_t *
num_copy(num_t *n)
{
    num_t *dup = num_allocate(n->size);

    memcpy(dup->digits, n->digits, n->size * sizeof(unsigned short));
    return(dup);
}

/*
 * Returns a + b as a new, reduced number.  Neither operand is modified.
 */
num_t *
num_add(num_t *a, num_t *b)
{
    unsigned len = MAX(a->size, b->size);
    num_t *sum = num_allocate(len + 1);
    unsigned carry = 0;
    unsigned k;

    for(k = 0; k < len; k++) {
        unsigned acc = carry;

        acc += k < a->size ? a->digits[k] : 0;
        acc += k < b->size ? b->digits[k] : 0;
        sum->digits[k] = acc % BASE;
        carry = acc / BASE;
    }
    sum->digits[len] = carry;
    return(num_reduce(sum));
}

/*
 * Returns a - b as a new, reduced number, or NULL if the result would be
 * negative.  The early size check assumes both operands are reduced.
 */
num_t *
num_sub(num_t *a, num_t *b)
{
    unsigned len, k;
    unsigned borrow = 0;
    num_t *diff;

    if(b->size > a->size)
        return(NULL);
    len = a->size;
    diff = num_allocate(len);
    for(k = 0; k < len; k++) {
        unsigned minuend = a->digits[k];
        unsigned subtrahend = borrow + (k < b->size ? b->digits[k] : 0);

        borrow = subtrahend > minuend;
        if(borrow)
            minuend += BASE;
        diff->digits[k] = minuend - subtrahend;
    }
    if(borrow != 0) {
        /* a < b: the final borrow has nowhere to go */
        num_free(diff);
        return(NULL);
    }
    return(num_reduce(diff));
}

/*
 * Returns a * b as a new, reduced number (schoolbook multiplication).
 * With 16-bit digits, carry + digit + a_i * b_j is at most 2^32 - 1, so the
 * accumulator fits as long as unsigned has at least 32 bits.
 */
num_t *
num_mul(num_t *a, num_t *b)
{
    num_t *prod = num_allocate(a->size + b->size);
    unsigned ia, ib;

    for(ia = 0; ia < a->size; ia++) {
        unsigned carry = 0;
        unsigned da = a->digits[ia];

        for(ib = 0; ib < b->size; ib++) {
            carry += prod->digits[ia + ib] + da * b->digits[ib];
            prod->digits[ia + ib] = carry % BASE;
            carry /= BASE;
        }
        prod->digits[ia + b->size] = carry;
    }
    return(num_reduce(prod));
}

/*
 * Returns n * 2^b (n shifted left by b bits) as a new, reduced number.
 * n is not modified.
 */
num_t *
num_shift(num_t *n, unsigned b)
{
    num_t *r;
    unsigned i;
    unsigned bi = b / BBASE;    /* whole-digit part of the shift */
    unsigned bb = b % BBASE;    /* bit shift inside a digit */

    /* Digit i lands in digits[i + bi] and spills into digits[i + bi + 1],
       so n->size + bi + 1 digits are needed whatever bb is.  (Rounding b
       up to whole digits under-allocated by one when bb == 0.) */
    r = num_allocate(n->size + bi + 1);
    for(i = 0; i < n->size; i++) {
        /* At most 31 bits: no overflow, and no shift by BBASE when bb == 0. */
        unsigned v = (unsigned)n->digits[i] << bb;

        r->digits[i + bi] |= v % BASE;
        r->digits[i + bi + 1] |= v / BASE;
    }
    r = num_reduce(r);
    return(r);
}

/*
 * Binary long division of a by b.  Returns the quotient and, if rrest is
 * not NULL, stores the remainder in *rrest.  Both results are new, reduced
 * numbers; a and b are not modified.  Prints an error and exits if b is
 * zero.
 */
num_t *
num_div(num_t *a, num_t *b, num_t **rrest)
{
    unsigned ba, bb;
    unsigned i;
    num_t *q, *r, *nr, *sb;

    ba = num_bits(a);
    bb = num_bits(b);
    if(bb == 0) {
        fprintf(stderr, "Division by zero\n");
        exit(1);
    }
    if(ba < bb) {
        if(rrest != NULL)
            *rrest = num_copy(a);
        return(num_allocate(0));
    }
    q = num_allocate(a->size - b->size + 1);
    r = num_copy(a);
    /* From the highest possible quotient bit down: whenever b * 2^i still
       fits in the running remainder, subtract it and set quotient bit i. */
    for(i = ba - bb + 1; i-- > 0; ) {
        sb = num_shift(b, i);
        nr = num_sub(r, sb);
        num_free(sb);
        if(nr != NULL) {
            q->digits[i / BBASE] |= 1u << (i % BBASE);
            num_free(r);
            r = nr;
        }
    }
    if(rrest != NULL)
        *rrest = num_reduce(r);
    else
        num_free(r);
    return(num_reduce(q));
}

/*
 * Returns n mod m for a machine-word modulus m (m must be non-zero).
 * Accumulates sum(digits[i] * BASE^i) mod m, with k holding BASE^i mod m.
 * Intermediates are unsigned long long: in plain unsigned, k * digit and
 * k * BASE overflow as soon as m >= 2^16.
 */
unsigned
num_mod_int(num_t *n, unsigned m)
{
    unsigned long long r = 0, k = 1;
    unsigned i;

    for(i = 0; i < n->size; i++) {
        r = (r + k * n->digits[i]) % m;
        k = (k * BASE) % m;
    }
    return((unsigned)r);
}

/*
 * Replaces n by n mod m.  Takes ownership of n (it is freed) and returns
 * the remainder as a new number.  When m is NULL, n is returned as is.
 */
num_t *
num_reduce_mod(num_t *n, num_t *m)
{
    num_t *quot, *rem;

    if(m != NULL) {
        quot = num_div(n, m, &rem);
        num_free(quot);
        num_free(n);
        n = rem;
    }
    return(n);
}

/*
 * Square-and-multiply modular exponentiation that ignores the `shift`
 * lowest bits of the exponent: returns n^(p >> shift) mod m as a new
 * number (no reduction at all when m is NULL).  n, p and m are not
 * modified.
 */
num_t *
num_expmod_shift(num_t *n, num_t *p, num_t *m, unsigned shift)
{
    num_t *r, *k, *t;
    unsigned i, b;

    /* Reduce the accumulator and the base first.  The first products then
       stay small, and an empty exponent yields 1 mod m (0 when m is 1). */
    r = num_reduce_mod(num_of_int(1), m);
    k = num_reduce_mod(num_copy(n), m);
    b = num_bits(p);
    for(i = shift; i < b; i++) {
        /* invariant: k == n^(2^(i - shift)) mod m */
        if(i > shift) {
            t = num_mul(k, k);
            num_free(k);
            k = num_reduce_mod(t, m);
        }
        if(((p->digits[i / BBASE] >> (i % BBASE)) & 1) != 0) {
            t = num_mul(r, k);
            num_free(r);
            r = num_reduce_mod(t, m);
        }
    }
    num_free(k);
    return(r);
}

/*
 * Returns n^p mod m as a new number (unreduced power when m is NULL).
 */
num_t *
num_expmod(num_t *n, num_t *p, num_t *m)
{
    const unsigned from_bit = 0;    /* use the whole exponent */

    return(num_expmod_shift(n, p, m, from_bit));
}

/*
 * Three-way comparison: returns 1 if a > b, -1 if a < b, 0 if equal.
 * Numbers of different sizes are ordered by size alone, so the operands
 * must either be reduced or have the same size.
 */
int
num_compare(num_t *a, num_t *b)
{
    unsigned k;

    if(a->size != b->size)
        return(a->size > b->size ? 1 : -1);
    for(k = a->size; k > 0; k--) {
        unsigned short da = a->digits[k - 1];
        unsigned short db = b->digits[k - 1];

        if(da != db)
            return(da > db ? 1 : -1);
    }
    return(0);
}

/*
 * Returns a random number in [0, n) as a new, reduced number.  Uses
 * rejection sampling over numbers with n's bit length, so the draw is as
 * uniform as rand() is.  n must be reduced and non-zero (exits otherwise).
 */
num_t *
num_rand(num_t *n)
{
    unsigned s, b, i;
    num_t *r;

    if(n->size == 0) {
        fprintf(stderr, "num_rand: empty range\n");
        exit(1);
    }
    s = n->size;
    r = num_allocate(s);
    /* Significant bits of n's top digit; 0 means all BBASE bits are used. */
    b = num_bits(n) % BBASE;
    while(1) {
        for(i = 0; i < s; i++)
            r->digits[i] = rand() % BASE;
        /* Mask only a partially used top digit.  With b == 0 the mask
           (1 << 0) - 1 would clear the whole digit, so r could never reach
           the top digit's range (and would always be 0 for a one-digit n). */
        if(b != 0)
            r->digits[s - 1] &= (1 << b) - 1;
        /* r is not reduced, but it has the same size as n, which is all
           num_compare needs. */
        if(num_compare(n, r) > 0)
            return(num_reduce(r));
    }
}

/*
 * Parses a hexadecimal string, with an optional "0x"/"0X" prefix, into a
 * new, reduced number.  An empty digit string yields zero.  Prints an error
 * and exits on any non-hexadecimal character.
 */
num_t *
num_parse(const char *t)
{
    unsigned len, pos, nibble;
    char c;
    num_t *r;

    if(t[0] == '0' && (t[1] == 'x' || t[1] == 'X'))
        t += 2;
    len = strlen(t);
    r = num_allocate((len * 4 + BBASE - 1) / BBASE);
    /* t is read left to right (most significant first); pos is the index of
       the current nibble counted from the least significant end. */
    for(pos = len; pos-- > 0; t++) {
        c = *t;
        if(c >= '0' && c <= '9')
            nibble = c - '0';
        else if(c >= 'a' && c <= 'f')
            nibble = c - 'a' + 10;
        else if(c >= 'A' && c <= 'F')
            nibble = c - 'A' + 10;
        else {
            fprintf(stderr, "Invalid character: %c\n", c);
            exit(1);
        }
        r->digits[pos * 4 / BBASE] |= nibble << (pos * 4 % BBASE);
    }
    return(num_reduce(r));
}

/*
 * Writes n to stdout in hexadecimal: "0x" followed by uppercase digits,
 * or "0" when n->size is 0.  No newline is printed.
 */
void
num_print(num_t *n)
{
    unsigned k;

    if(n->size == 0) {
        printf("0");
        return;
    }
    k = n->size - 1;
    printf("0x%X", n->digits[k]);
    /* lower digits are zero-padded to BBASE / 4 hex characters */
    while(k-- > 0)
        printf("%0*X", BBASE / 4, n->digits[k]);
}

/*
 * One round of the Miller-Rabin test with a random non-zero witness
 * r < n.  Writing n - 1 = d * 2^p with d odd, returns 1 ("probably prime")
 * if r^d == 1 mod n or r^(d * 2^j) == n - 1 mod n for some j < p, and 0
 * otherwise.  n must be at least 2: for n == 1 every witness drawn is 0 and
 * the draw loop never ends.
 */
int
is_prime_once(num_t *n)
{
    unsigned b, p;
    num_t *r, *s, *t, *nm1;
    int isprime = 0;

    t = num_of_int(1);
    nm1 = num_sub(n, t);
    num_free(t);

    /* Draw a non-zero witness, releasing rejected draws. */
    while(1) {
        r = num_rand(n);
        if(r->size != 0)
            break;
        num_free(r);
    }
    /* p = number of trailing zero bits of n - 1 */
    b = num_bits(nm1);
    for(p = 0; p < b; p++)
        if(((nm1->digits[p / BBASE] >> (p % BBASE)) & 1) != 0)
            break;
    s = num_expmod_shift(r, nm1, n, p);    /* s = r^d mod n */
    num_free(r);
    if(s->size == 1 && s->digits[0] == 1) {
        isprime = 1;
    } else {
        /* Look for n - 1 among r^d, r^(2d), ..., r^(2^(p-1) d). */
        while(p-- > 0) {
            if(num_compare(s, nm1) == 0) {
                isprime = 1;
                break;
            }
            if(p > 0) {
                t = num_mul(s, s);
                num_free(s);
                s = num_reduce_mod(t, n);
            }
        }
    }
    num_free(s);
    num_free(nm1);
    return(isprime);
}

/* All primes below BASE in increasing order, built lazily by
   compute_small_primes(); n_small_primes is the table length. */
unsigned *small_primes = NULL;
unsigned n_small_primes;

/*
 * Fills small_primes with the primes in [2, BASE) using a sieve of
 * Eratosthenes.  Does nothing once the table exists.  Exits on allocation
 * failure.
 */
void
compute_small_primes(void)
{
    unsigned char *composite;
    unsigned k, mult, slot;

    if(small_primes != NULL)
        return;
    composite = calloc(BASE, 1);
    if(composite == NULL) {
        perror("Out of memory");
        exit(1);
    }
    for(k = 2; k * k < BASE; k++) {
        if(composite[k])
            continue;
        /* smaller multiples were already crossed out by smaller factors */
        for(mult = k * k; mult < BASE; mult += k)
            composite[mult] = 1;
    }
    for(k = 2; k < BASE; k++)
        if(!composite[k])
            n_small_primes++;
    small_primes = malloc(n_small_primes * sizeof *small_primes);
    if(small_primes == NULL) {
        perror("Out of memory");
        exit(1);
    }
    slot = 0;
    for(k = 2; k < BASE; k++)
        if(!composite[k])
            small_primes[slot++] = k;
    free(composite);
}

/*
 * Trial division of n (expected reduced) by every prime below BASE.
 * Returns 0 if one of them is a proper divisor of n, 1 otherwise.  Every
 * prime below BASE is in the table, so for n < BASE the answer is exact
 * primality: a prime n == p is accepted (it used to be rejected for
 * dividing itself), and 0 and 1 are rejected.
 */
int
is_prime_small(num_t *n)
{
    unsigned i, p;

    compute_small_primes();
    for(i = 0; i < n_small_primes; i++) {
        p = small_primes[i];
        if(num_mod_int(n, p) == 0)
            /* p divides n: n is prime only if it is p itself */
            return(n->size == 1 && n->digits[0] == p);
    }
    /* No prime below BASE divides n.  Either n >= BASE (no small factor),
       or n == 1 (0 is divisible by 2 and was handled above). */
    return(n->size > 1);
}

/*
 * Primality test.  Numbers below BASE (at most one digit) are decided
 * exactly by trial division.  Larger ones must survive trial division by
 * all primes below BASE, then f rounds of Miller-Rabin (is_prime_once).
 * Returns 1 if n is (probably) prime, 0 otherwise.
 */
int
is_prime(num_t *n, unsigned f)
{
    if(n->size <= 1) {
        unsigned v = n->size == 0 ? 0 : n->digits[0];
        unsigned i;

        /* Miller-Rabin cannot handle n < 2, and the trial division below
           would reject primes < BASE as their own divisors. */
        if(v < 2)
            return(0);
        compute_small_primes();
        for(i = 0; i < n_small_primes
                && small_primes[i] * small_primes[i] <= v; i++)
            if(v % small_primes[i] == 0)
                return(0);
        return(1);
    }
    if(!is_prime_small(n))
        return(0);
    while(f-- > 0)
        if(!is_prime_once(n))
            return(0);
    return(1);
}

/*
 * Draws random odd numbers of exactly `size` digits (the top bit is forced,
 * giving size * BBASE bits) until one passes is_prime(n, f), and returns it
 * as a new number.  Prints an error and exits if size is 0.
 */
num_t *
find_prime(unsigned size, unsigned f)
{
    num_t *n;
    unsigned i;

    if(size == 0) {
        fprintf(stderr, "find_prime: size must be at least 1\n");
        exit(1);
    }
    n = num_allocate(size);
    do {
        for(i = 0; i < size; i++)
            n->digits[i] = rand() % BASE;
        n->digits[size - 1] |= 1 << (BBASE - 1);
        n->digits[0] |= 1;
    } while(!is_prime(n, f));
    return(n);
}

/*
 * Extended Euclidean algorithm on natural numbers.
 *
 * Stores in *rgcd the gcd g of a and b, and in *rca and *rcb non-negative
 * coefficients such that g = (*rca) * a - (*rcb) * b.  num_t cannot hold
 * negative values, so the signed Bezout coefficients are kept as
 * magnitudes plus the sign flag `inv`.  The three results are new numbers;
 * a and b are not modified.
 *
 * NOTE(review): if a is 0 and b is not, the final correction evaluates
 * num_sub(cbb, cab) with cbb = 0 and cab = 1, so *rcb is set to NULL
 * (g = b would need a coefficient of -1 on b).
 */
void
bezout(num_t *a, num_t *b, num_t **rgcd, num_t **rca, num_t **rcb)
{
    num_t *wa, *wb, *q, *r;
    num_t *caa, *cab, *cba, *cbb, *cra, *crb;
    num_t *t;
    unsigned inv;

    /* Invariant (wa, wb: the two current Euclid remainders):
       inv = 0 => wa = + caa * a - cab b    wb = - cba * a + cbb * b
       inv = 1 => wa = - caa * a + cab b    wb = + cba * a - cbb * b */
    wa = num_copy(a);
    wb = num_copy(b);
    caa = num_of_int(1);
    cab = num_of_int(0);
    cba = num_of_int(0);
    cbb = num_of_int(1);
    inv = 0;
    while(wb->size != 0) {
	/* One Euclid step: wa = q * wb + r, i.e. r = wa - q * wb. */
	q = num_div(wa, wb, &r);
	/* crX = caX + q * cbX: coefficients of r.  r has the sign pattern of
	   wa, the opposite of wb's; after (wa, wb) <- (wb, r) the pattern is
	   mirrored, hence inv = !inv below. */
	t = num_mul(q, cba);
	cra = num_add(caa, t);
	num_free(t);
	t = num_mul(q, cbb);
	crb = num_add(cab, t);
	num_free(t);
	num_free(q);
	num_free(wa);
	num_free(caa);
	num_free(cab);
	wa = wb;
	wb = r;
	caa = cba;
	cab = cbb;
	cba = cra;
	cbb = crb;
	inv = !inv;
    }
    /* Now wb == 0 and wa is the gcd. */
    if(inv) {
	/* wa = -caa*a + cab*b and 0 = wb = cba*a - cbb*b, so
	   wa = wa + wb = (cba - caa) * a - (cbb - cab) * b, which is the
	   inv = 0 form.  num_sub returns NULL if a difference is negative.
	   TODO: prove that this really always works (it does not for a = 0,
	   see the NOTE above). */
	t = num_sub(cba, caa);
	num_free(caa);
	caa = t;
	t = num_sub(cbb, cab);
	num_free(cab);
	cab = t;
	inv = 0;
    }
    num_free(wb);
    num_free(cba);
    num_free(cbb);
    *rgcd = wa;
    *rca = caa;
    *rcb = cab;
}

/*
 * Prints "l: " (only when l is not NULL), then n, then a newline.
 */
void
show_num(const char *l, num_t *n)
{
    if(l == NULL) {
        num_print(n);
    } else {
        printf("%s: ", l);
        num_print(n);
    }
    putchar('\n');
}

/*
 * Generates and prints an RSA key:
 *   p, q  two distinct primes of `size` digits each (see find_prime;
 *         `fiability` is the number of Miller-Rabin rounds),
 *   m     the modulus p * q,
 *   e     the first prime >= 65537 with gcd(e, phi) == 1, where
 *         phi = (p - 1) * (q - 1),
 *   d     the coefficient from bezout with e * d - k * phi = 1, so that
 *         e * d = 1 mod phi.
 */
void
genrsa(unsigned size, unsigned fiability)
{
    num_t *p, *q, *m, *phi;
    num_t *t1, *t2, *t3;
    num_t *e, *d, *gcd;
    unsigned ce, ok;
    unsigned j;

    p = find_prime(size, fiability);
    /* (p - 1)(q - 1) is only the totient of m when p != q.  Equal draws are
       likely with small sizes, so draw q again until it differs. */
    while(1) {
        q = find_prime(size, fiability);
        if(num_compare(p, q) != 0)
            break;
        num_free(q);
    }
    m = num_mul(p, q);
    t1 = num_of_int(1);
    t2 = num_sub(p, t1);
    t3 = num_sub(q, t1);
    phi = num_mul(t2, t3);
    num_free(t1);
    num_free(t2);
    num_free(t3);
    for(ce = 65537; 1; ce += 2) {
        /* skip composite candidates (trial division) */
        for(j = 2; j * j <= ce; j++)
            if(ce % j == 0)
                break;
        if(ce % j == 0)
            continue;
        e = num_allocate(2);
        e->digits[0] = ce % BASE;
        e->digits[1] = ce / BASE;
        bezout(e, phi, &gcd, &d, &t1);
        num_free(t1);
        ok = gcd->size == 1 && gcd->digits[0] == 1;
        num_free(gcd);
        if(ok)
            break;
        num_free(d);
        num_free(e);
    }
    show_num("p", p);
    show_num("q", q);
    show_num("m", m);
    show_num("e", e);
    show_num("d", d);
    num_free(p);
    num_free(q);
    num_free(m);
    num_free(phi);
    num_free(e);
    num_free(d);
}

/*
 * Raw RSA transformation: returns v^e mod m as a new number.  Encryption
 * and decryption are the same operation with the public or the private
 * exponent.
 */
num_t *
rsa(num_t *m, num_t *e, num_t *v)
{
    num_t *result = num_expmod(v, e, m);

    return(result);
}

/*
 * Seeds rand().  When the RSA_DEMO_SEED environment variable is set, its
 * value is parsed by strtol with base 0 (decimal, 0x hex or leading-0
 * octal).  Otherwise the current time is used.
 */
void
init_rand(void)
{
    const char *seed = getenv("RSA_DEMO_SEED");

    if(seed != NULL)
        srand(strtol(seed, NULL, 0));
    else
        srand(time(NULL));
}

/*
 * Command-line front end.  The first argument selects an action and the
 * rest are its operands: numbers are hexadecimal (see num_parse), while
 * sizes and fiabilities are parsed with strtol base 0.  On an unknown
 * action or a wrong argument count, prints the usage and exits with
 * status 1.
 */
int
main(int argc, char **argv)
{
    char **args;
    int nargs;
    const char *action;

    init_rand();
    /* Skip the program name: args[0] is the action, args[1..] its operands. */
    nargs = argc - 1;
    args = argv + 1;
    action = nargs > 0 ? args[0] : "";

    if(nargs == 2 && strcmp(action, "print") == 0) {
        num_t *n = num_parse(args[1]);

        num_print(n);
        printf("\n");
        return(0);
    }
    if(nargs == 3 && strcmp(action, "add") == 0) {
        num_t *a = num_parse(args[1]);
        num_t *b = num_parse(args[2]);

        show_num(NULL, num_add(a, b));
        return(0);
    }
    if(nargs == 3 && strcmp(action, "sub") == 0) {
        num_t *a = num_parse(args[1]);
        num_t *b = num_parse(args[2]);
        num_t *diff = num_sub(a, b);

        if(diff != NULL)
            num_print(diff);
        else
            printf("-?");
        printf("\n");
        return(0);
    }
    if(nargs == 3 && strcmp(action, "mul") == 0) {
        num_t *a = num_parse(args[1]);
        num_t *b = num_parse(args[2]);

        show_num(NULL, num_mul(a, b));
        return(0);
    }
    if(nargs == 3 && strcmp(action, "div") == 0) {
        num_t *a = num_parse(args[1]);
        num_t *b = num_parse(args[2]);
        num_t *rest;
        num_t *quot = num_div(a, b, &rest);

        show_num("q", quot);
        show_num("r", rest);
        return(0);
    }
    if((nargs == 3 || nargs == 4) && strcmp(action, "pow") == 0) {
        num_t *n = num_parse(args[1]);
        num_t *p = num_parse(args[2]);
        num_t *m = nargs == 4 ? num_parse(args[3]) : NULL;

        show_num(NULL, num_expmod(n, p, m));
        return(0);
    }
    if(nargs == 3 && strcmp(action, "isprime") == 0) {
        num_t *n = num_parse(args[1]);
        unsigned f = strtol(args[2], NULL, 0);

        printf("%d\n", is_prime(n, f));
        return(0);
    }
    if(nargs == 3 && strcmp(action, "findprime") == 0) {
        unsigned s = strtol(args[1], NULL, 0);
        unsigned f = strtol(args[2], NULL, 0);

        show_num(NULL, find_prime(s, f));
        return(0);
    }
    if(nargs == 3 && strcmp(action, "bezout") == 0) {
        num_t *a = num_parse(args[1]);
        num_t *b = num_parse(args[2]);
        num_t *gcd, *ca, *cb;

        bezout(a, b, &gcd, &ca, &cb);
        show_num("gcd", gcd);
        show_num("ca", ca);
        show_num("cb", cb);
        printf("\n");
        return(0);
    }
    if(nargs == 3 && strcmp(action, "genrsa") == 0) {
        unsigned s = strtol(args[1], NULL, 0);
        unsigned f = strtol(args[2], NULL, 0);

        genrsa(s, f);
        return(0);
    }
    if(nargs == 4 && strcmp(action, "rsa") == 0) {
        num_t *m = num_parse(args[1]);
        num_t *e = num_parse(args[2]);
        num_t *v = num_parse(args[3]);

        show_num(NULL, rsa(m, e, v));
        return(0);
    }
    printf("Usage: rsa_demo action [arguments]\n"
        "  print number\n"
        "  add a b\n"
        "  sub a b\n"
        "  mul a b\n"
        "  div a b\n"
        "  pow n p [m]\n"
        "  isprime n fiability\n"
        "  findprime size fiability\n"
        "  bezout a b\n"
        "  genrsa size fiability\n"
        "  rsa m e secret\n"
        );
    exit(1);
}
