// pqc/external/flint-2.4.3/nmod_matxx.h
//
// 556 lines
// 18 KiB
// C++

/*=============================================================================
This file is part of FLINT.
FLINT is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
FLINT is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with FLINT; if not, write to the Free Software
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
=============================================================================*/
/******************************************************************************
Copyright (C) 2013 Tom Bachmann
******************************************************************************/
#ifndef NMOD_MATXX_H
#define NMOD_MATXX_H
#include <string>
#include <vector>
#include "nmod_mat.h"
#include "nmod_vecxx.h"
#include "fmpz_matxx.h" // for modular reduction
#include "permxx.h"
#include "flintxx/flint_exception.h"
#include "flintxx/ltuple.h"
#include "flintxx/matrix.h"
// TODO addmul
// TODO default argument for mat_solve_triu etc?
// TODO nullspace member
// TODO unnecessary perm copies in set_lu*
namespace flint {
// Lazy operations introduced by this header. Each macro invocation provides
// an operations::<name>_op tag (the form referenced by the outsize and rules
// specialisations later in this file); the BINOP/THREEARY suffix gives the
// number of operands.
FLINT_DEFINE_BINOP(solve_vec)
FLINT_DEFINE_BINOP(mul_strassen)
// Triangular solves: (matrix, right-hand side, bool flag).
FLINT_DEFINE_THREEARY(solve_tril)
FLINT_DEFINE_THREEARY(solve_tril_classical)
FLINT_DEFINE_THREEARY(solve_tril_recursive)
FLINT_DEFINE_THREEARY(solve_triu)
FLINT_DEFINE_THREEARY(solve_triu_classical)
FLINT_DEFINE_THREEARY(solve_triu_recursive)
// Multi-modular CRT reconstruction with a precomputed fmpz_combxx.
FLINT_DEFINE_THREEARY(multi_CRT_precomp)
namespace detail {
// Matrix traits for nmod_matxx expressions. This generic default is
// specialised for nmod_matxx, nmod_matxx_ref and nmod_matxx_srcref further
// down, once those types exist.
template<class Mat>
struct nmod_matxx_traits : matrices::generic_traits<Mat> { };
} // detail
// Expression-template class for matrices over Z/nZ (wrapping nmod_mat_t).
// The immediate, reference and read-only reference instantiations are the
// typedefs nmod_matxx, nmod_matxx_ref and nmod_matxx_srcref declared after
// this class; other instantiations are lazy expressions.
template<class Operation, class Data>
class nmod_matxx_expression
: public expression<derived_wrapper<nmod_matxx_expression>, Operation, Data>
{
public:
typedef expression<derived_wrapper< ::flint::nmod_matxx_expression>,
Operation, Data> base_t;
typedef detail::nmod_matxx_traits<nmod_matxx_expression> traits_t;
// NOTE: the macros below introduce names used later in the class
// (e.g. evaluated_t, _mat()), so their order matters.
FLINTXX_DEFINE_BASICS(nmod_matxx_expression)
FLINTXX_DEFINE_CTORS(nmod_matxx_expression)
FLINTXX_DEFINE_C_REF(nmod_matxx_expression, nmod_mat_struct, _mat)
// These only make sense with immediates
// Context built from the underlying nmod_mat's own `mod` field.
nmodxx_ctx_srcref _ctx() const
{return nmodxx_ctx_srcref::make(_mat()->mod);}
// These work on any expression without evaluation
// Locate the modulus context by searching the expression's operands.
nmodxx_ctx_srcref estimate_ctx() const
{
return tools::find_nmodxx_ctx(*this);
}
// The modulus n of the (estimated) context.
mp_limb_t modulus() const {return estimate_ctx().n();}
// Allocate an evaluated rows x cols matrix whose modulus is the one found
// in expression e; used when the framework needs a temporary.
template<class Expr>
static evaluated_t create_temporary_rowscols(
const Expr& e, slong rows, slong cols)
{
return evaluated_t(rows, cols, tools::find_nmodxx_ctx(e).n());
}
FLINTXX_DEFINE_MATRIX_METHODS(traits_t)
// Reduce an fmpz matrix (or fmpz matrix expression, evaluated first) into a
// new matrix of the same shape with the given modulus, via
// fmpz_mat_get_nmod_mat. Only enabled for fmpz_matxx types.
template<class Fmpz_mat>
static nmod_matxx_expression reduce(const Fmpz_mat& mat,
mp_limb_t modulus,
typename mp::enable_if<traits::is_fmpz_matxx<Fmpz_mat> >::type* = 0)
{
nmod_matxx_expression res(mat.rows(), mat.cols(), modulus);
fmpz_mat_get_nmod_mat(res._mat(), mat.evaluate()._mat());
return res;
}
// Random-matrix factories: each creates a rows x cols matrix with modulus n
// and fills it with the matching set_rand* member below.
static nmod_matxx_expression randtest(slong rows, slong cols, mp_limb_t n,
frandxx& state)
{
nmod_matxx_expression res(rows, cols, n);
res.set_randtest(state);
return res;
}
static nmod_matxx_expression randfull(slong rows, slong cols, mp_limb_t n,
frandxx& state)
{
nmod_matxx_expression res(rows, cols, n);
res.set_randfull(state);
return res;
}
static nmod_matxx_expression randrank(slong rows, slong cols, mp_limb_t n,
frandxx& state, slong rank)
{
nmod_matxx_expression res(rows, cols, n);
res.set_randrank(state, rank);
return res;
}
static nmod_matxx_expression randtril(slong rows, slong cols, mp_limb_t n,
frandxx& state, bool unit)
{
nmod_matxx_expression res(rows, cols, n);
res.set_randtril(state, unit);
return res;
}
static nmod_matxx_expression randtriu(slong rows, slong cols, mp_limb_t n,
frandxx& state,
bool unit)
{
nmod_matxx_expression res(rows, cols, n);
res.set_randtriu(state, unit);
return res;
}
template<class Vec>
static nmod_matxx_expression randpermdiag(slong rows, slong cols, mp_limb_t n,
frandxx& state, const Vec& v)
{
nmod_matxx_expression res(rows, cols, n);
res.set_randpermdiag(state, v);
return res;
}
// A rows x cols matrix with modulus n, as produced by the plain constructor.
static nmod_matxx_expression zero(slong rows, slong cols, mp_limb_t n)
{return nmod_matxx_expression(rows, cols, n);}
// these only make sense with targets
// Thin wrappers over the nmod_mat_rand* C functions, drawing from state.
void set_randtest(frandxx& state)
{nmod_mat_randtest(_mat(), state._data());}
void set_randfull(frandxx& state)
{nmod_mat_randfull(_mat(), state._data());}
void set_randrank(frandxx& state, slong rank)
{nmod_mat_randrank(_mat(), state._data(), rank);}
void set_randtril(frandxx& state, bool unit)
{nmod_mat_randtril(_mat(), state._data(), unit);}
void set_randtriu(frandxx& state, bool unit)
{nmod_mat_randtriu(_mat(), state._data(), unit);}
// Forwards v's array and size; returns nmod_mat_randpermdiag's int result
// unchanged (per FLINT docs, the parity of the permutation -- confirm).
template<class Vec>
int set_randpermdiag(frandxx& state, const Vec& v)
{
return nmod_mat_randpermdiag(_mat(), state._data(), v._array(), v.size());
}
// Note the C function takes count before state.
void apply_randops(frandxx& state, slong count)
{nmod_mat_randops(_mat(), count, state._data());}
// In-place reduced row echelon form; returns nmod_mat_rref's slong result
// (presumably the rank -- see FLINT docs).
slong set_rref() {return nmod_mat_rref(_mat());}
void set_zero() {nmod_mat_zero(_mat());}
// In-place LU decomposition. The returned tuple holds (slong result of the
// nmod_mat_lu* call, permutation of length rows() filled in by that call).
typedef mp::make_tuple<slong, permxx>::type lu_rt;
lu_rt set_lu(bool rank_check = false)
{
lu_rt res = mp::make_tuple<slong, permxx>::make(0, permxx(rows()));
res.first() = nmod_mat_lu(res.second()._data(), _mat(), rank_check);
return res;
}
lu_rt set_lu_classical(bool rank_check = false)
{
lu_rt res = mp::make_tuple<slong, permxx>::make(0, permxx(rows()));
res.first() = nmod_mat_lu_classical(
res.second()._data(), _mat(), rank_check);
return res;
}
lu_rt set_lu_recursive(bool rank_check = false)
{
lu_rt res = mp::make_tuple<slong, permxx>::make(0, permxx(rows()));
res.first() = nmod_mat_lu_recursive(
res.second()._data(), _mat(), rank_check);
return res;
}
// these cause evaluation
// Each evaluates the expression into a temporary and queries it.
slong rank() const {return nmod_mat_rank(this->evaluate()._mat());}
bool is_zero() const {return nmod_mat_is_zero(this->evaluate()._mat());}
bool is_empty() const {return nmod_mat_is_empty(this->evaluate()._mat());}
bool is_square() const {return nmod_mat_is_square(this->evaluate()._mat());}
// lazy members
// Return unevaluated expressions; the computation happens in the matching
// rules (rules namespace below) when the result is evaluated.
FLINTXX_DEFINE_MEMBER_BINOP(solve)
FLINTXX_DEFINE_MEMBER_BINOP(mul_classical)
FLINTXX_DEFINE_MEMBER_BINOP(mul_strassen)
FLINTXX_DEFINE_MEMBER_UNOP(inv)
FLINTXX_DEFINE_MEMBER_UNOP(transpose)
FLINTXX_DEFINE_MEMBER_UNOP_RTYPE(nmodxx, trace)
FLINTXX_DEFINE_MEMBER_UNOP_RTYPE(nmodxx, det)
//FLINTXX_DEFINE_MEMBER_UNOP_RTYPE(???, nullspace) // TODO
FLINTXX_DEFINE_MEMBER_3OP(solve_tril)
FLINTXX_DEFINE_MEMBER_3OP(solve_tril_recursive)
FLINTXX_DEFINE_MEMBER_3OP(solve_tril_classical)
FLINTXX_DEFINE_MEMBER_3OP(solve_triu)
FLINTXX_DEFINE_MEMBER_3OP(solve_triu_recursive)
FLINTXX_DEFINE_MEMBER_3OP(solve_triu_classical)
};
namespace detail {
struct nmod_mat_data;
} // detail
// The three concrete matrix types:
//   nmod_matxx        - owns its nmod_mat_t (storage in detail::nmod_mat_data)
//   nmod_matxx_ref    - mutable reference to an existing nmod_mat_struct
//   nmod_matxx_srcref - read-only reference to an existing nmod_mat_struct
typedef nmod_matxx_expression<operations::immediate, detail::nmod_mat_data> nmod_matxx;
typedef nmod_matxx_expression<operations::immediate,
flint_classes::ref_data<nmod_matxx, nmod_mat_struct> > nmod_matxx_ref;
typedef nmod_matxx_expression<operations::immediate, flint_classes::srcref_data<
nmod_matxx, nmod_matxx_ref, nmod_mat_struct> > nmod_matxx_srcref;
// Dimension queries and entry access for nmod_matxx. The member templates
// accept any object exposing _mat() and estimate_ctx(), so the same traits
// serve the ref and srcref forms as well.
template<>
struct matrix_traits<nmod_matxx>
{
    template<class Mat> static slong rows(const Mat& mat)
    {
        const slong nrows = nmod_mat_nrows(mat._mat());
        return nrows;
    }

    template<class Mat> static slong cols(const Mat& mat)
    {
        const slong ncols = nmod_mat_ncols(mat._mat());
        return ncols;
    }

    // Read-only handle to entry (row, col), tagged with the matrix's modulus.
    template<class Mat> static nmodxx_srcref at(const Mat& mat, slong row, slong col)
    {
        const nmodxx_ctx_srcref ctx = mat.estimate_ctx();
        return nmodxx_srcref::make(nmod_mat_entry(mat._mat(), row, col), ctx);
    }

    // Mutable handle to entry (row, col), tagged with the matrix's modulus.
    template<class Mat> static nmodxx_ref at(Mat& mat, slong row, slong col)
    {
        const nmodxx_ctx_srcref ctx = mat.estimate_ctx();
        return nmodxx_ref::make(nmod_mat_entry(mat._mat(), row, col), ctx);
    }
};
namespace traits {
// All three nmod_matxx forms carry an nmodxx context (their modulus); these
// flags presumably let tools::find_nmodxx_ctx pick the modulus up from
// matrix operands inside larger expressions.
template<> struct has_nmodxx_ctx<nmod_matxx> : mp::true_ { };
template<> struct has_nmodxx_ctx<nmod_matxx_ref> : mp::true_ { };
template<> struct has_nmodxx_ctx<nmod_matxx_srcref> : mp::true_ { };
} // traits
namespace detail {
template<>
struct nmod_matxx_traits<nmod_matxx_srcref>
: matrices::generic_traits_srcref<nmodxx_srcref> { };
template<>
struct nmod_matxx_traits<nmod_matxx_ref>
: matrices::generic_traits_ref<nmodxx_ref> { };
template<> struct nmod_matxx_traits<nmod_matxx>
: matrices::generic_traits_nonref<nmodxx_ref, nmodxx_srcref> { };
struct nmod_mat_data
{
typedef nmod_mat_t& data_ref_t;
typedef const nmod_mat_t& data_srcref_t;
nmod_mat_t inner;
nmod_mat_data(slong m, slong n, mp_limb_t modulus)
{
nmod_mat_init(inner, m, n, modulus);
}
nmod_mat_data(const nmod_mat_data& o)
{
nmod_mat_init_set(inner, o.inner);
}
nmod_mat_data(nmod_matxx_srcref o)
{
nmod_mat_init_set(inner, o._data().inner);
}
~nmod_mat_data() {nmod_mat_clear(inner);}
};
} // detail
namespace matrices {
// Output shapes of the lazy operations defined in this header.
// A Strassen product has the same shape as an ordinary product (times).
template<>
struct outsize<operations::mul_strassen_op>
: outsize<operations::times> { };
// Every triangular-solve variant has the same output shape as solve.
template<> struct outsize<operations::solve_tril_op>
: outsize<operations::solve_op> { };
template<> struct outsize<operations::solve_tril_classical_op>
: outsize<operations::solve_op> { };
template<> struct outsize<operations::solve_tril_recursive_op>
: outsize<operations::solve_op> { };
template<> struct outsize<operations::solve_triu_op>
: outsize<operations::solve_op> { };
template<> struct outsize<operations::solve_triu_classical_op>
: outsize<operations::solve_op> { };
template<> struct outsize<operations::solve_triu_recursive_op>
: outsize<operations::solve_op> { };
}
// temporary instantiation stuff
FLINTXX_DEFINE_TEMPORARY_RULES(nmod_matxx)
// Rule-condition shorthands: _S matches nmod_matxx forms usable as sources,
// _T those usable as assignment targets (see the assignment rule below).
#define NMOD_MATXX_COND_S FLINTXX_COND_S(nmod_matxx)
#define NMOD_MATXX_COND_T FLINTXX_COND_T(nmod_matxx)
namespace traits {
// Type predicate built on flint_classes::is_Base<nmod_matxx, T>; presumably
// true for nmod_matxx and its ref/srcref forms.
template<class T> struct is_nmod_matxx
: flint_classes::is_Base<nmod_matxx, T> { };
} // traits
namespace rules {
// Assignment, swap, equality and pretty-printing map directly onto the
// corresponding nmod_mat_* C functions.
FLINT_DEFINE_DOIT_COND2(assignment, NMOD_MATXX_COND_T, NMOD_MATXX_COND_S,
nmod_mat_set(to._mat(), from._mat()))
FLINTXX_DEFINE_SWAP(nmod_matxx, nmod_mat_swap(e1._mat(), e2._mat()))
FLINTXX_DEFINE_EQUALS(nmod_matxx, nmod_mat_equal(e1._mat(), e2._mat()))
// The comma expression makes the rule yield 1 regardless of what
// nmod_mat_print_pretty returns (presumably void).
FLINT_DEFINE_PRINT_PRETTY_COND(NMOD_MATXX_COND_S,
(nmod_mat_print_pretty(from._mat()), 1))
// Entry access m(i, j): the entry is copied into an nmodxx with set_nored,
// i.e. without a further reduction (matrix entries are presumably already
// reduced modulo n).
FLINT_DEFINE_THREEARY_EXPR_COND3(mat_at_op, nmodxx,
NMOD_MATXX_COND_S, traits::fits_into_slong, traits::fits_into_slong,
to.set_nored(nmod_mat_entry(e1._mat(), e2, e3)))
// Arithmetic: matrix product, scalar product by an nmodxx, sum,
// difference and negation.
FLINT_DEFINE_BINARY_EXPR_COND2(times, nmod_matxx,
NMOD_MATXX_COND_S, NMOD_MATXX_COND_S,
nmod_mat_mul(to._mat(), e1._mat(), e2._mat()))
FLINT_DEFINE_CBINARY_EXPR_COND2(times, nmod_matxx,
NMOD_MATXX_COND_S, NMODXX_COND_S,
nmod_mat_scalar_mul(to._mat(), e1._mat(), e2._limb()))
FLINT_DEFINE_BINARY_EXPR_COND2(plus, nmod_matxx,
NMOD_MATXX_COND_S, NMOD_MATXX_COND_S,
nmod_mat_add(to._mat(), e1._mat(), e2._mat()))
FLINT_DEFINE_BINARY_EXPR_COND2(minus, nmod_matxx,
NMOD_MATXX_COND_S, NMOD_MATXX_COND_S,
nmod_mat_sub(to._mat(), e1._mat(), e2._mat()))
FLINT_DEFINE_UNARY_EXPR_COND(negate, nmod_matxx, NMOD_MATXX_COND_S,
nmod_mat_neg(to._mat(), from._mat()))
// Transpose, and trace (returned as an nmodxx via set_nored).
FLINT_DEFINE_UNARY_EXPR_COND(transpose_op, nmod_matxx, NMOD_MATXX_COND_S,
nmod_mat_transpose(to._mat(), from._mat()))
FLINT_DEFINE_UNARY_EXPR_COND(trace_op, nmodxx, NMOD_MATXX_COND_S,
to.set_nored(nmod_mat_trace(from._mat())))
// Explicit-algorithm products. mul_classical / mul_strassen must call the
// dedicated FLINT kernels. Calling nmod_mat_mul (the generic routine used by
// the `times` rule) would silently ignore the algorithm the caller asked for.
FLINT_DEFINE_BINARY_EXPR_COND2(mul_classical_op, nmod_matxx,
    NMOD_MATXX_COND_S, NMOD_MATXX_COND_S,
    nmod_mat_mul_classical(to._mat(), e1._mat(), e2._mat()))
FLINT_DEFINE_BINARY_EXPR_COND2(mul_strassen_op, nmod_matxx,
    NMOD_MATXX_COND_S, NMOD_MATXX_COND_S,
    nmod_mat_mul_strassen(to._mat(), e1._mat(), e2._mat()))
// Determinant, returned as an nmodxx via set_nored.
FLINT_DEFINE_UNARY_EXPR_COND(det_op, nmodxx, NMOD_MATXX_COND_S,
to.set_nored(nmod_mat_det(from._mat())))
// Inverse. nmod_mat_inv's return value is handed to execution_check, which
// presumably throws a flint_exception when inversion fails (e.g. a singular
// matrix) -- confirm in flintxx/flint_exception.h.
FLINT_DEFINE_UNARY_EXPR_COND(inv_op, nmod_matxx, NMOD_MATXX_COND_S,
execution_check(nmod_mat_inv(to._mat(), from._mat()),
"inv", "nmod_mat"))
// Defines the rule for <name>_op: (matrix, matrix, bool) -> matrix, where
// the bool operand is passed as the last argument of nmod_mat_<name>
// (per FLINT docs, the unit-diagonal flag -- confirm).
#define NMOD_MATXX_DEFINE_SOLVE_TRI(name) \
FLINT_DEFINE_THREEARY_EXPR_COND3(name##_op, nmod_matxx, \
NMOD_MATXX_COND_S, NMOD_MATXX_COND_S, tools::is_bool, \
nmod_mat_##name(to._mat(), e1._mat(), e2._mat(), e3))
NMOD_MATXX_DEFINE_SOLVE_TRI(solve_tril)
NMOD_MATXX_DEFINE_SOLVE_TRI(solve_tril_classical)
NMOD_MATXX_DEFINE_SOLVE_TRI(solve_tril_recursive)
NMOD_MATXX_DEFINE_SOLVE_TRI(solve_triu)
NMOD_MATXX_DEFINE_SOLVE_TRI(solve_triu_classical)
NMOD_MATXX_DEFINE_SOLVE_TRI(solve_triu_recursive)
// General solve A X = B (matrix right-hand side) and A x = b (vector
// right-hand side); failures go through execution_check.
FLINT_DEFINE_BINARY_EXPR_COND2(solve_op, nmod_matxx,
NMOD_MATXX_COND_S, NMOD_MATXX_COND_S,
execution_check(nmod_mat_solve(to._mat(), e1._mat(), e2._mat()),
"solve", "nmod_mat"))
FLINT_DEFINE_BINARY_EXPR_COND2(solve_op, nmod_vecxx,
NMOD_MATXX_COND_S, NMOD_VECXX_COND_S,
execution_check(nmod_mat_solve_vec(to._array(), e1._mat(), e2._array()),
"solve_vec", "nmod_mat"))
namespace rdetail {
// Result of nullspace: an ltuple of (slong returned by nmod_mat_nullspace,
// matrix filled in by it). Per FLINT docs the slong is the nullity and the
// matrix's columns span the nullspace -- confirm.
typedef make_ltuple<mp::make_tuple<slong, nmod_matxx>::type >::type
nmod_mat_nullspace_rt;
} // rdetail
FLINT_DEFINE_UNARY_EXPR_COND(nullspace_op, rdetail::nmod_mat_nullspace_rt,
NMOD_MATXX_COND_S, to.template get<0>() = nmod_mat_nullspace(
to.template get<1>()._mat(), from._mat()))
} // rules
//////////////////////////////////////////////////////////////////////////////
// nmod_mat_vector class
//////////////////////////////////////////////////////////////////////////////
// This class stores a vector of nmod_matxx with differing moduli. It is *not*
// an expression template class!
class nmod_mat_vector
{
private:
nmod_mat_t* data;
std::size_t size_;
void init(const nmod_mat_vector& o)
{
size_ = o.size_;
data = new nmod_mat_t[size_];
for(std::size_t i = 0;i < size_;++i)
nmod_mat_init_set(data[i], o.data[i]);
}
public:
~nmod_mat_vector() {delete[] data;}
nmod_mat_vector(slong rows, slong cols, const std::vector<mp_limb_t>& primes)
{
size_ = primes.size();
data = new nmod_mat_t[primes.size()];
for(std::size_t i = 0;i < primes.size();++i)
nmod_mat_init(data[i], rows, cols, primes[i]);
}
nmod_mat_vector(const nmod_mat_vector& o)
{
init(o);
}
nmod_mat_vector& operator=(const nmod_mat_vector& o)
{
delete[] data;
init(o);
return *this;
}
nmod_matxx_ref operator[](std::size_t idx)
{return nmod_matxx_ref::make(data[idx]);}
nmod_matxx_srcref operator[](std::size_t idx) const
{return nmod_matxx_srcref::make(data[idx]);}
std::size_t size() const {return size_;}
const nmod_mat_t* _data() const {return data;}
nmod_mat_t* _data() {return data;}
bool operator==(const nmod_mat_vector& o)
{
if(size() != o.size())
return false;
for(std::size_t i = 0;i < size();++i)
if((*this)[i] != o[i])
return false;
return true;
}
bool operator!=(const nmod_mat_vector& o)
{
return !(*this == o);
}
template<class Fmpz_mat>
void set_multi_mod(const Fmpz_mat& m,
typename mp::enable_if<traits::is_fmpz_matxx<Fmpz_mat> >::type* = 0)
{
fmpz_mat_multi_mod_ui(data, size(), m.evaluate()._mat());
}
template<class Fmpz_mat>
void set_multi_mod_precomp(const Fmpz_mat& m,
const fmpz_combxx& comb,
typename mp::enable_if<traits::is_fmpz_matxx<Fmpz_mat> >::type* = 0)
{
fmpz_mat_multi_mod_ui_precomp(data, size(), m.evaluate()._mat(),
comb._comb(), comb._temp());
}
};
/////////////////////////////////////////////////////////////////////////////
// chinese remaindering
/////////////////////////////////////////////////////////////////////////////
// Note this operates on fmpz_matxx and fmpz_combxx (as well as nmod_matxx).
// We define it here to deal with the circular dependencies. fmpz_matxx.h
// includes nmod_matxx.h at the bottom.
// Reduce fmpz matrix m modulo every prime in primes. The result holds one
// residue matrix per prime, each shaped like m.
template<class Fmpz_mat>
inline nmod_mat_vector multi_mod(const Fmpz_mat& m,
        const std::vector<mp_limb_t>& primes,
        typename mp::enable_if<traits::is_fmpz_matxx<Fmpz_mat> >::type* = 0)
{
    const slong nrows = m.rows();
    const slong ncols = m.cols();
    nmod_mat_vector residues(nrows, ncols, primes);
    residues.set_multi_mod(m);
    return residues;
}

// As multi_mod, but uses a precomputed fmpz_combxx for the primes.
template<class Fmpz_mat>
inline nmod_mat_vector multi_mod_precomp(const Fmpz_mat& m,
        const std::vector<mp_limb_t>& primes,
        const fmpz_combxx& comb,
        typename mp::enable_if<traits::is_fmpz_matxx<Fmpz_mat> >::type* = 0)
{
    const slong nrows = m.rows();
    const slong ncols = m.cols();
    nmod_mat_vector residues(nrows, ncols, primes);
    residues.set_multi_mod_precomp(m, comb);
    return residues;
}
namespace matrices {
// outsize computation for multi-CRT
// The reconstructed fmpz matrix takes the shape of the first residue matrix.
// m._data().first() is presumably the nmod_mat_vector operand of the
// multi_CRT rules.
// NOTE(review): element [0] is read unconditionally, so an empty
// nmod_mat_vector would index out of bounds. Callers presumably never pass
// one; confirm.
struct outsize_CRT
{
template<class Mat>
static slong rows(const Mat& m)
{
return m._data().first()[0].rows();
}
template<class Mat>
static slong cols(const Mat& m)
{
return m._data().first()[0].cols();
}
};
template<> struct outsize<operations::multi_CRT_op> : outsize_CRT { };
template<> struct outsize<operations::multi_CRT_precomp_op> : outsize_CRT { };
}
namespace rules {
// Single-step CRT: combine fmpz matrix e1 (residues modulo e2) with nmod
// matrix e3 into `to`; the bool e4 is forwarded as fmpz_mat_CRT_ui's last
// argument. e1 is only read, so it is matched with the source condition
// (_S, which admits srcref operands too), like every other source operand
// in this file.
FLINT_DEFINE_FOURARY_EXPR_COND4(CRT_op, fmpz_matxx,
    FMPZ_MATXX_COND_S, FMPZXX_COND_S, NMOD_MATXX_COND_S, tools::is_bool,
    fmpz_mat_CRT_ui(to._mat(), e1._mat(), e2._fmpz(), e3._mat(), e4))
// Multi-modular reconstruction from all residues in an nmod_mat_vector.
// The C functions take a non-const nmod_mat_t* parameter, while e1 is
// const here. The const_cast makes the constness removal explicit, instead
// of hiding it in a C-style cast. The residues are presumably only read --
// confirm against the FLINT fmpz_mat docs.
FLINT_DEFINE_BINARY_EXPR2(multi_CRT_op, fmpz_matxx, nmod_mat_vector, bool,
    fmpz_mat_multi_CRT_ui(to._mat(), const_cast<nmod_mat_t*>(e1._data()),
        e1.size(), e2))
FLINT_DEFINE_THREEARY_EXPR(multi_CRT_precomp_op, fmpz_matxx,
    nmod_mat_vector, fmpz_combxx, bool,
    fmpz_mat_multi_CRT_ui_precomp(to._mat(), const_cast<nmod_mat_t*>(e1._data()),
        e1.size(), e2._comb(), e2._temp(), e3))
} // rules
} // flint
#endif