diff --git a/src/poly.c b/src/poly.c index 7cb10f9..d813310 100644 --- a/src/poly.c +++ b/src/poly.c @@ -409,6 +409,152 @@ OUT_OF_LOOP: return true; } + +/** + * Invert the polynomial a modulo p. + * + * @param a polynomial to invert + * @param Fq polynomial [out] + * @param ctx NTRU context + */ +bool pb_inverse_poly_p(pb_poly *a, + pb_poly *Fp, + ntru_context *ctx) +{ + int k = 0, + j = 0; + pb_poly *a_tmp, *b, *c, *f, *g; + mp_int mp_modulus, mp_minus; + + /* general initialization of temp variables */ + init_integer(&mp_modulus); + init_integer(&mp_minus); + MP_SET_INT(&mp_modulus, (unsigned long)(ctx->p)); + MP_SET_INT(&mp_minus, 1); + mp_neg(&mp_minus, &mp_minus); + b = build_polynom(NULL, ctx->N + 1, ctx); + MP_SET(&(b->terms[0]), 1); + c = build_polynom(NULL, ctx->N + 1, ctx); + f = build_polynom(NULL, ctx->N + 1, ctx); + PB_COPY(a, f); + + /* set g(x) = x^N − 1 */ + g = build_polynom(NULL, ctx->N + 1, ctx); + MP_SET(&(g->terms[0]), 1); + mp_neg(&(g->terms[0]), &(g->terms[0])); + MP_SET(&(g->terms[ctx->N]), 1); + + /* avoid side effects */ + a_tmp = build_polynom(NULL, ctx->N, ctx); + PB_COPY(a, a_tmp); + erase_polynom(Fp, ctx->N); + + printf("f: "); draw_polynom(f); + printf("g: "); draw_polynom(g); + + while (1) { + while (mp_cmp_d(&(f->terms[0]), 0) == MP_EQ) { + printf("blah\n"); + for (unsigned int i = 1; i <= ctx->N; i++) { + /* f(x) = f(x) / x */ + MP_COPY(&(f->terms[i]), &(f->terms[i - 1])); + /* c(x) = c(x) * x */ + MP_COPY(&(c->terms[ctx->N - i]), &(c->terms[ctx->N + 1 - i])); + } + MP_SET(&(f->terms[ctx->N]), 0); + MP_SET(&(c->terms[0]), 0); + k++; + } + + if (get_degree(f) == 0) + goto OUT_OF_LOOP2; + + if (get_degree(f) < get_degree(g)) { + pb_exch(f, g); + pb_exch(b, c); + } + + { + pb_poly *u, *c_tmp, *g_tmp; + mp_int mp_tmp; + + init_integer(&mp_tmp); + u = build_polynom(NULL, ctx->N, ctx); + g_tmp = build_polynom(NULL, ctx->N + 1, ctx); + PB_COPY(g, g_tmp); + c_tmp = build_polynom(NULL, ctx->N + 1, ctx); + PB_COPY(c, c_tmp); + + /* u = ((f[0] mod p) * (g[0] inverse mod p) mod p) */ + printf("u before: "); draw_polynom(u); + MP_COPY(&(f->terms[0]), &mp_tmp); /* don't change f[0] */ + MP_INVMOD(&(g->terms[0]), &mp_modulus, &(u->terms[0])); + MP_MOD(&mp_tmp, &mp_modulus, &mp_tmp); + MP_MUL(&(u->terms[0]), &mp_tmp, &(u->terms[0])); + MP_MOD(&(u->terms[0]), &mp_modulus, &(u->terms[0])); + + /* f = f - u * g mod p */ + printf("f before: "); draw_polynom(f); + PB_MUL(g_tmp, u, g_tmp); + PB_SUB(f, g_tmp, f); + PB_MOD(f, &mp_modulus, f, ctx->N + 1); + + /* b = b - u * c mod p */ + printf("b before: "); draw_polynom(b); + PB_MUL(c_tmp, u, c_tmp); + PB_SUB(b, c_tmp, b); + PB_MOD(b, &mp_modulus, b, ctx->N + 1); + printf("u after: "); draw_polynom(u); + printf("f after: "); draw_polynom(f); + printf("g after: "); draw_polynom(g); + printf("b after: "); draw_polynom(b); + + mp_clear(&mp_tmp); + delete_polynom_multi(u, c_tmp, g_tmp, NULL); + } + } + +OUT_OF_LOOP2: + k = k % ctx->N; + + /* Fp(x) = x^(N-k) * b(x) */ + for (int i = ctx->N - 1; i >= 0; i--) { + + /* b(X) = f[0]^(-1) * b(X) (mod p) */ + { + pb_poly *poly_tmp; + + poly_tmp = build_polynom(NULL, ctx->N + 1, ctx); + + MP_INVMOD(&(f->terms[0]), &mp_modulus, &(poly_tmp->terms[0])); + MP_MOD(&(b->terms[i]), &mp_modulus, &(b->terms[i])); + MP_MUL(&(b->terms[i]), &(poly_tmp->terms[0]), &(b->terms[i])); + + delete_polynom(poly_tmp); + } + + j = i - k; + if (j < 0) + j = j + ctx->N; + MP_COPY(&(b->terms[i]), &(Fp->terms[j])); + + /* delete_polynom(f_tmp); */ + } + + /* pull into positive space */ + for (int i = ctx->N - 1; i >= 0; i--) + if (mp_cmp_d(&(Fp->terms[i]), 0) == MP_LT) + MP_ADD(&(Fp->terms[i]), &mp_modulus, &(Fp->terms[i])); + + mp_clear(&mp_modulus); + delete_polynom_multi(a_tmp, b, c, f, g, NULL); + + /* TODO: check if the f * Fq = 1 (mod p) condition holds true */ + + return true; +} + +/** * Print the polynomial in a human readable format to stdout. * * @param poly to draw diff --git a/src/poly.h b/src/poly.h index 77a9b54..4f3beea 100644 --- a/src/poly.h +++ b/src/poly.h @@ -105,6 +105,14 @@ mp_error_to_string(result)); \ } +#define MP_INVMOD(...) \ +{ \ + int result; \ + if ((result = mp_invmod(__VA_ARGS__)) != MP_OKAY) \ + NTRU_ABORT("Error computing modular inverse. %s", \ + mp_error_to_string(result)); \ +} + #define MP_EXPT_D(...) \ { \ int result; \ @@ -182,6 +190,10 @@ bool pb_inverse_poly_q(pb_poly *a, pb_poly *Fq, ntru_context *ctx); +bool pb_inverse_poly_p(pb_poly *a, + pb_poly *Fp, + ntru_context *ctx); + void draw_polynom(pb_poly * const poly); #endif /* NTRU_POLY_H */