/*=============================================================================

    This file is part of FLINT.

    FLINT is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    FLINT is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with FLINT; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301 USA

=============================================================================*/
/******************************************************************************

    Copyright (C) 2012 William Hart

******************************************************************************/

#define ulong ulongxx /* interferes with system includes */
#include <stdio.h>
#include <stdlib.h>
#undef ulong
#define ulong mp_limb_t
#include <gmp.h>
#include "flint.h"
#include "ulong_extras.h"

slong n_sqrtmod_2pow(mp_limb_t ** sqrt, mp_limb_t a, slong exp) 
{
    mp_limb_t r = (a & 1);
    mp_limb_t * s;

    if (exp == 0) /* special case for sqrt of 0 mod 1 */
    {
        *sqrt = flint_malloc(sizeof(mp_limb_t));
        (*sqrt)[0] = 0;

        return 1;
    }

    if (exp == 1) /* special case mod 2 */
    {
        *sqrt = flint_malloc(sizeof(mp_limb_t));

        if (r) (*sqrt)[0] = 1;
        else (*sqrt)[0] = 0;

        return 1;
    }

    if (exp == 2) /* special case mod 4 */
    {
        r = (a & 3);
        
        if (r < 2) /* 0, 1 mod 4 */
        {
           *sqrt = flint_malloc(sizeof(mp_limb_t)*2);

           (*sqrt)[0] = r;
           (*sqrt)[1] = r + 2;
           
           return 2;
        } else /* 2, 3 mod 4 */
        {
           *sqrt = NULL;
           return 0;
        }
    }

    if (r) /* a is odd */
    {
        mp_limb_t roots[2];
        slong i, ex, pow;

        if ((a & 7) != 1) /* check square root exists */
        {
            *sqrt = NULL;
            return 0;
        }

        roots[0] = 1; /* one of each pair of square roots mod 8 */
        roots[1] = 3;

        pow = 8;
        for (ex = 3; ex < exp; ex++, pow *= 2) /* lift roots */
        {
            i = 0;
            
            r = roots[0];
            if (((r*r) & (2*pow - 1)) == (a & (2*pow - 1)))
               roots[i++] = r;
               
            r = pow - r;
            if (((r*r) & (2*pow - 1)) == (a & (2*pow - 1)))
            {
               roots[i++] = r;
               if (i == 2) continue;
            }
               
            r = roots[1];
            if (((r*r) & (2*pow - 1)) == (a & (2*pow - 1)))
            {
               roots[i++] = r;
               if (i == 2) continue;
            }
               
            r = pow - r;
            roots[i] = r;
        }

        *sqrt = flint_malloc(sizeof(mp_limb_t)*4);

        (*sqrt)[0] = roots[0]; /* write out both pairs of roots */
        (*sqrt)[1] = pow - roots[0];
        (*sqrt)[2] = roots[1];
        (*sqrt)[3] = pow - roots[1];
        
        return 4;        
    } else /* a is even */
    {
        slong i, k, num, pow;
        
        for (k = 2; k <= exp; k++) /* find highest power of 2 dividing a */
        {
            if (a & ((UWORD(1)<<k) - 1))
                break;
        }
        k--;
        
        if (a == 0)
        {
           a = (UWORD(1)<<(exp - k/2));
           num = (UWORD(1)<<(k/2));
           s = flint_malloc(num*sizeof(mp_limb_t));
           for (i = 0; i < num; i++)
               s[i] = i*a;

           *sqrt = s;
           return num;
        }

        if (k & 1) /* not a square */
        {
            *sqrt = NULL;
            return 0;
        }
        
        pow = (UWORD(1)<<k);

        exp -= k;
        a /= pow;

        num = n_sqrtmod_2pow(&s, a, exp); /* divide through by 2^k and recurse */
        
        a = (UWORD(1)<<(k/2));
        r = a*(UWORD(1)<<exp);
        
        if (num == 0) /* check that roots were actually returned */
        {
            *sqrt = NULL;
            return 0;
        }

        for (i = 0; i < num; i++) /* multiply roots by 2^(k/2) */
            s[i] *= a;

        if (num == 1) /* one root */
        {
           s = flint_realloc(s, a*sizeof(mp_limb_t));

           for (i = 1; i < a; i++)
              s[i] = s[i - 1] + r;

        } else if (num == 2) /* two roots */
        {
           s = flint_realloc(s, 2*a*sizeof(mp_limb_t));

           for (i = 1; i < a; i++)
           {
              s[2*i] = s[2*i - 2] + r;
              s[2*i + 1] = s[2*i - 1] + r;
           }
        } else /* num == 4, i.e. four roots */
        {
           s = flint_realloc(s, 4*a*sizeof(mp_limb_t));

           for (i = 1; i < a; i++)
           {
              s[4*i] = s[4*i - 4] + r;
              s[4*i + 1] = s[4*i - 3] + r;
              s[4*i + 2] = s[4*i - 2] + r;
              s[4*i + 3] = s[4*i - 1] + r;
           }
        }

        *sqrt = s;

        return num*a;
    }
}

slong n_sqrtmod_primepow(mp_limb_t ** sqrt, mp_limb_t a, mp_limb_t p, slong exp) 
{
    mp_limb_t r, ex, pow, k, a1, pinv, powinv;
    mp_limb_t * s;
    slong i, num;

    if (exp < 0)
    {
        flint_printf("Exception (n_sqrtmod_primepow). exp must be non-negative.\n");
        abort();
    }

    if (exp == 0) /* special case, sqrt of 0 mod 1 */
    {
        *sqrt = flint_malloc(sizeof(mp_limb_t));
        (*sqrt)[0] = 0;

        return 1;
    }

    if (p == 2) /* deal with p = 2 specially */
       return n_sqrtmod_2pow(sqrt, a, exp);

    if (exp == 1) /* special case, roots mod p */
    {
        r = n_sqrtmod(a, p);
        
        if (r == 0 && a != 0)
        {
            *sqrt = NULL;
            return 0;
        }

        *sqrt = flint_malloc(sizeof(mp_limb_t)*(1 + (r != 0)));
        (*sqrt)[0] = r;
        if (r) (*sqrt)[1] = p - r;
       
        return 1 + (r != 0);
    }

    pinv = n_preinvert_limb(p);
    a1 = n_mod2_preinv(a, p, pinv);
    r = n_sqrtmod(a1, p);

    if (r == 0 && a1 != 0)
    {
        *sqrt = NULL;
        return 0;
    }
    
    if (r) /* gcd(a, p) = 1, p is odd, lift r and p - r */
    {
        for (ex = 1, pow = p; ex < exp; ex++, pow *= p) /* lift root */
        {
            /* set k = ((r^2 - a) mod (p^(ex + 1))) / p^ex */
            powinv = n_preinvert_limb(pow*p);
            a1 = n_mulmod2_preinv(r, r, pow*p, powinv);
            k = (a < a1 ? a1 - a : a - a1);
            k = n_mod2_preinv(k, pow*p, powinv);
            k /= pow;
            if (a < a1)
                k = n_negmod(k, p);

            /* set k = k / 2r mod p */
            a1 = n_mulmod2_preinv(2, r, p, pinv);
            k = n_mulmod2_preinv(n_invmod(a1, p), k, p, pinv);

            /* set r = r + k*p^ex */
            r += k*pow;
        }

        *sqrt = flint_malloc(sizeof(mp_limb_t)*2);
        (*sqrt)[0] = r;
        (*sqrt)[1] = pow - r;

        return 2;
    } else /* special case, one root lifts to p roots */
    {
        for (k = 1, pow = p; k < exp; k++) /* find highest power of p dividing a */
        {
            mp_limb_t pow2 = pow * p;

            if (a % pow2 != 0)
                break;
            
            pow = pow2;
        }

        if (a == 0) /* special case, a == 0 */
        {
           a = n_pow(p, exp - k/2);
           num = n_pow(p, k/2);
           s = flint_malloc(num*sizeof(mp_limb_t));
           for (i = 0; i < num; i++)
               s[i] = i*a;

           *sqrt = s;
           return num;
        }

        if (k & 1) /* not a square */
        {
            *sqrt = NULL;
            return 0;
        }

        exp -= k;
        a /= pow;

        num = n_sqrtmod_primepow(&s, a, p, exp); /* divide through by p^k and recurse */

        if (num == 0)
        {
            *sqrt = NULL;
            return 0;
        }

        a = n_pow(p, k/2);
        r = a*n_pow(p, exp);
        
        s[0] *= a; /* multiply roots by p^(k/2) */
        s[1] *= a;
        
        s = flint_realloc(s, 2*a*sizeof(mp_limb_t));

        for (i = 1; i < a; i++) 
        {
            s[2*i] = s[2*i - 2] + r;
            s[2*i + 1] = s[2*i - 1] + r;
        }

        *sqrt = s;

        return 2*a;
    }
}