pqc/external/flint-2.4.3/ulong_extras/sqrtmod_primepow.c

346 lines
8.4 KiB
C

/*=============================================================================
This file is part of FLINT.
FLINT is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
FLINT is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with FLINT; if not, write to the Free Software
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
=============================================================================*/
/******************************************************************************
Copyright (C) 2012 William Hart
******************************************************************************/
#define ulong ulongxx /* interferes with system includes */
#include <stdio.h>
#include <stdlib.h>
#undef ulong
#define ulong mp_limb_t
#include <gmp.h>
#include "flint.h"
#include "ulong_extras.h"
/*
    Computes all square roots of a modulo 2^exp.

    sqrt  receives a pointer to a newly allocated array (flint_malloc /
          flint_realloc) containing the roots, or NULL when there are none.
          This function never frees the array; presumably the caller is
          expected to release it with flint_free -- confirm against callers.
    a     the value whose roots are wanted. Only its residue modulo 2^exp
          matters; it need not be reduced.
    exp   exponent of the modulus. Expected to be non-negative and at most
          the number of bits in an mp_limb_t; this is not checked here.

    Returns the number of roots written to *sqrt (0 if a is not a square
    modulo 2^exp).
*/
slong n_sqrtmod_2pow(mp_limb_t ** sqrt, mp_limb_t a, slong exp)
{
    mp_limb_t r = (a & 1);
    mp_limb_t * s;

    if (exp == 0) /* special case for sqrt of 0 mod 1 */
    {
        *sqrt = flint_malloc(sizeof(mp_limb_t));
        (*sqrt)[0] = 0;
        return 1;
    }

    if (exp == 1) /* special case mod 2: the only root is a mod 2 */
    {
        *sqrt = flint_malloc(sizeof(mp_limb_t));
        (*sqrt)[0] = r;
        return 1;
    }

    if (exp == 2) /* special case mod 4 */
    {
        r = (a & 3);

        if (r >= 2) /* 2, 3 mod 4 are not squares */
        {
            *sqrt = NULL;
            return 0;
        }

        /* 0, 1 mod 4: roots are r and r + 2 */
        *sqrt = flint_malloc(sizeof(mp_limb_t)*2);
        (*sqrt)[0] = r;
        (*sqrt)[1] = r + 2;
        return 2;
    }

    if (r) /* a is odd */
    {
        mp_limb_t roots[2];
        mp_limb_t pow; /* unsigned so that doubling it can never overflow */
        slong i, ex;

        if ((a & 7) != 1) /* odd squares are 1 mod 8 */
        {
            *sqrt = NULL;
            return 0;
        }

        roots[0] = 1; /* one of each pair {x, -x} of square roots mod 8 */
        roots[1] = 3;
        pow = 8;

        /*
            Lift from modulus pow = 2^ex to 2*pow. The candidates are the two
            stored roots and their negatives modulo pow; the ones whose
            square agrees with a modulo 2*pow are kept.
        */
        for (ex = 3; ex < exp; ex++, pow *= 2)
        {
            const mp_limb_t mask = 2*pow - 1;
            const mp_limb_t target = a & mask;

            i = 0;

            r = roots[0];
            if (((r*r) & mask) == target)
                roots[i++] = r;

            r = pow - r;
            if (((r*r) & mask) == target)
            {
                roots[i++] = r;
                if (i == 2) continue;
            }

            r = roots[1];
            if (((r*r) & mask) == target)
            {
                roots[i++] = r;
                if (i == 2) continue;
            }

            r = pow - r;
            roots[i] = r;
        }

        /* here pow == 2^exp (for exp >= 3): emit each root and its negative */
        *sqrt = flint_malloc(sizeof(mp_limb_t)*4);
        (*sqrt)[0] = roots[0];
        (*sqrt)[1] = pow - roots[0];
        (*sqrt)[2] = roots[1];
        (*sqrt)[3] = pow - roots[1];

        return 4;
    }
    else /* a is even */
    {
        mp_limb_t mult, step, total;
        slong i, k, num;

        /*
            k = min(exp, number of trailing zero bits of a). Bit 0 is known
            to be clear. Every shift count stays below exp, so no shift by
            the full limb width is ever performed.
        */
        for (k = 1; k < exp && ((a >> k) & 1) == 0; k++)
            ;

        /*
            a = 0 mod 2^exp (a itself may be a non-zero multiple of 2^exp):
            the roots are exactly the multiples of 2^(exp - floor(exp/2)).
        */
        if (k == exp)
        {
            step = (UWORD(1)<<(exp - k/2));
            num = (UWORD(1)<<(k/2));

            s = flint_malloc(num*sizeof(mp_limb_t));

            for (i = 0; i < num; i++)
                s[i] = i*step;

            *sqrt = s;
            return num;
        }

        if (k & 1) /* odd power of 2 exactly divides a: not a square */
        {
            *sqrt = NULL;
            return 0;
        }

        /* divide through by 2^k and recurse on the odd part */
        exp -= k;
        a >>= k;

        num = n_sqrtmod_2pow(&s, a, exp);

        if (num == 0) /* check that roots were actually returned */
        {
            *sqrt = NULL;
            return 0;
        }

        mult = (UWORD(1)<<(k/2)); /* 2^(k/2) */
        step = mult<<exp;         /* 2^(k/2) * 2^(exp - k) with the original exp */

        for (i = 0; i < num; i++) /* multiply roots by 2^(k/2) */
            s[i] *= mult;

        /*
            Each scaled root is determined modulo step only, so it yields
            mult roots modulo the original modulus: append shifted copies of
            the previous group of num roots.
        */
        total = num*mult;
        s = flint_realloc(s, total*sizeof(mp_limb_t));

        for (i = num; (mp_limb_t) i < total; i++)
            s[i] = s[i - num] + step;

        *sqrt = s;
        return total;
    }
}
/*
    Computes all square roots of a modulo p^exp.

    sqrt  receives a pointer to a newly allocated array (flint_malloc /
          flint_realloc) containing the roots, or NULL when there are none.
          This function never frees the array; presumably the caller is
          expected to release it with flint_free -- confirm against callers.
    a     the value whose roots are wanted. Only its residue modulo p^exp
          matters; it need not be reduced.
    p     the base of the modulus. p == 2 is delegated to n_sqrtmod_2pow;
          otherwise p is assumed to be an odd prime (not checked), and
          p^exp is assumed to fit in an mp_limb_t.
    exp   exponent of the modulus; a negative value prints a message and
          aborts.

    Returns the number of roots written to *sqrt (0 if a is not a square
    modulo p^exp).
*/
slong n_sqrtmod_primepow(mp_limb_t ** sqrt, mp_limb_t a, mp_limb_t p, slong exp)
{
    mp_limb_t r, ex, pow, k, a1, pinv, powinv, mult, step, total;
    mp_limb_t * s;
    slong i, num;

    if (exp < 0)
    {
        flint_printf("Exception (n_sqrtmod_primepow). exp must be non-negative.\n");
        abort();
    }

    if (exp == 0) /* special case, sqrt of 0 mod 1 */
    {
        *sqrt = flint_malloc(sizeof(mp_limb_t));
        (*sqrt)[0] = 0;
        return 1;
    }

    if (p == 2) /* deal with p = 2 specially */
        return n_sqrtmod_2pow(sqrt, a, exp);

    /*
        Find one root modulo p. a is reduced first on every path;
        NOTE(review): n_sqrtmod presumably expects a reduced argument --
        confirm against its documentation.
    */
    pinv = n_preinvert_limb(p);
    a1 = n_mod2_preinv(a, p, pinv);
    r = n_sqrtmod(a1, p);

    if (r == 0 && a1 != 0) /* no root mod p, hence none mod p^exp */
    {
        *sqrt = NULL;
        return 0;
    }

    if (exp == 1) /* special case, roots mod p: {0} or {r, p - r} */
    {
        *sqrt = flint_malloc(sizeof(mp_limb_t)*(1 + (r != 0)));
        (*sqrt)[0] = r;
        if (r) (*sqrt)[1] = p - r;
        return 1 + (r != 0);
    }

    if (r) /* gcd(a, p) = 1, p is odd, lift r and p - r */
    {
        for (ex = 1, pow = p; ex < exp; ex++, pow *= p) /* lift root */
        {
            /* set k = ((a - r^2) mod (p^(ex + 1))) / p^ex, reduced mod p */
            powinv = n_preinvert_limb(pow*p);
            a1 = n_mulmod2_preinv(r, r, pow*p, powinv);
            k = (a < a1 ? a1 - a : a - a1);
            k = n_mod2_preinv(k, pow*p, powinv);
            k /= pow;
            if (a < a1) /* the absolute difference was used: restore sign */
                k = n_negmod(k, p);

            /* set k = k / 2r mod p */
            a1 = n_mulmod2_preinv(2, r, p, pinv);
            k = n_mulmod2_preinv(n_invmod(a1, p), k, p, pinv);

            /* set r = r + k*p^ex, so that r^2 = a mod p^(ex + 1) */
            r += k*pow;
        }

        /* here pow == p^exp */
        *sqrt = flint_malloc(sizeof(mp_limb_t)*2);
        (*sqrt)[0] = r;
        (*sqrt)[1] = pow - r;

        return 2;
    }
    else /* p divides a: one root mod p lifts to several roots */
    {
        /* k = min(exp, power of p dividing a), pow = p^k */
        for (k = 1, pow = p; k < (mp_limb_t) exp; k++)
        {
            mp_limb_t pow2 = pow * p;

            if (a % pow2 != 0)
                break;

            pow = pow2;
        }

        /*
            a = 0 mod p^exp (a itself may be a non-zero multiple of p^exp):
            the roots are exactly the multiples of p^(exp - floor(exp/2)).
        */
        if (k == (mp_limb_t) exp)
        {
            step = n_pow(p, exp - k/2);
            num = n_pow(p, k/2);

            s = flint_malloc(num*sizeof(mp_limb_t));

            for (i = 0; i < num; i++)
                s[i] = i*step;

            *sqrt = s;
            return num;
        }

        if (k & 1) /* odd power of p exactly divides a: not a square */
        {
            *sqrt = NULL;
            return 0;
        }

        /* divide through by p^k and recurse on the part coprime to p */
        exp -= k;
        a /= pow;

        num = n_sqrtmod_primepow(&s, a, p, exp);

        if (num == 0)
        {
            *sqrt = NULL;
            return 0;
        }

        mult = n_pow(p, k/2);      /* p^(k/2) */
        step = mult*n_pow(p, exp); /* p^(k/2) * p^(exp - k) with the original exp */

        for (i = 0; i < num; i++) /* multiply roots by p^(k/2) */
            s[i] *= mult;

        /*
            Each scaled root is determined modulo step only, so it yields
            mult roots modulo the original modulus: append shifted copies of
            the previous group of num roots. Using num instead of a
            hard-coded 2 keeps every access within the array returned by
            the recursive call.
        */
        total = num*mult;
        s = flint_realloc(s, total*sizeof(mp_limb_t));

        for (i = num; (mp_limb_t) i < total; i++)
            s[i] = s[i - num] + step;

        *sqrt = s;
        return total;
    }
}