/*
* SPDX-License-Identifier: CC0-1.0
*
* Copyright (C) 2024, 2025 W. Kosior <koszko@koszko.org>
*/
#include <stdbool.h>
#include <stdlib.h>
#include <sys/types.h>
#include <flint/flint.h>
#include <flint/fmpz_mod.h>
#include "pqcrypto_poly.h"
#include "pqcrypto_prng_getrandom.h"
/***
*** Modulo operations.
***/
/*
 * Initialise CTX for centered reduction modulo DIVISOR.
 *
 * The divisor is copied into the context and range_max is set to
 * (divisor - 1) / 2, rounded toward zero.  Release with mod_c0_ctx_clear().
 */
void mod_c0_ctx_init(mod_centered_0_ctx_t ctx, fmpz_t divisor) {
	fmpz_init_set(ctx->divisor, divisor);
	fmpz_init(ctx->range_max);

	/* range_max = (divisor - 1) >> 1 (truncating division by 2). */
	fmpz_sub_ui(ctx->range_max, ctx->divisor, 1);
	fmpz_tdiv_q_2exp(ctx->range_max, ctx->range_max, 1);
}
/* Release the integers owned by a context set up by mod_c0_ctx_init*(). */
void mod_c0_ctx_clear(mod_centered_0_ctx_t ctx) {
	fmpz_clear(ctx->range_max);
	fmpz_clear(ctx->divisor);
}
/*
 * Reduce VALUE in place to the representative congruent to it modulo
 * ctx->divisor that lies in [-range_max, divisor - 1 - range_max]
 * (assuming a positive divisor).
 */
void mod_c0(fmpz_t value, mod_centered_0_ctx_t const ctx) {
	/*
	 * Shift so the target range starts at zero, take the floor remainder
	 * (in [0, divisor) for a positive divisor), then shift back.
	 */
	fmpz_add(value, value, ctx->range_max);
	fmpz_fdiv_r(value, value, ctx->divisor);
	fmpz_sub(value, value, ctx->range_max);
}
/*
 * Initialise DST_CTX as an independent copy of SRC_CTX (both the divisor and
 * the precomputed range_max are copied).  Release with mod_c0_ctx_clear().
 */
void mod_c0_ctx_init_set(mod_centered_0_ctx_t dst_ctx,
			 mod_centered_0_ctx_t const src_ctx) {
	fmpz_init_set(dst_ctx->divisor, src_ctx->divisor);
	fmpz_init_set(dst_ctx->range_max, src_ctx->range_max);
}
/*
 * Set VALUE to a random element of the centered range used by mod_c0(),
 * i.e. [-range_max, divisor - 1 - range_max], drawing randomness from PRNG.
 *
 * PRNG is called with (buffer, byte count, PRNG_STATE) and is expected to
 * fill the whole buffer.  Aborts if the scratch buffer cannot be allocated.
 */
void mod_c0_prng(fmpz_t value, mod_centered_0_ctx_t const ctx,
		 prng_t prng, void * prng_state) {
	flint_bitcnt_t const divisor_bits = fmpz_bits(ctx->divisor);
	flint_bitcnt_t const partial_bits = divisor_bits % FLINT_BITS;
	ulong const limb_count =
		divisor_bits / FLINT_BITS + (partial_bits != 0 ? 1 : 0);
	size_t const byte_count = limb_count * sizeof(ulong);
	/* Keeps only the low PARTIAL_BITS bits of the most significant limb. */
	ulong const top_limb_mask =
		(((unsigned long long) 1) << partial_bits) - 1;
	ulong * const limbs = malloc(byte_count);
	bool accepted;

	if (limbs == NULL)
		abort();

	/*
	 * Rejection sampling: draw exactly as many bits as the divisor has and
	 * retry until the candidate is below the divisor, so that (given uniform
	 * PRNG output) every residue in [0, divisor) is equally likely.
	 */
	do {
		prng(limbs, byte_count, prng_state);
		if (partial_bits != 0)
			limbs[limb_count - 1] &= top_limb_mask;
		fmpz_set_ui_array(value, limbs, limb_count);
		accepted = fmpz_cmp(value, ctx->divisor) < 0;
	} while (!accepted);

	/* Move [0, divisor) onto the centered range produced by mod_c0(). */
	fmpz_sub(value, value, ctx->range_max);
	free(limbs);
}
/*
 * Like mod_c0_prng(), but with prng_getrandom as the randomness source.
 * No PRNG state object is passed (NULL).
 */
void mod_c0_rand(fmpz_t value, mod_centered_0_ctx_t const ctx) {
	void * const prng_state = NULL;

	mod_c0_prng(value, ctx, prng_getrandom, prng_state);
}
/***
*** Simple operations on polynomials.
***/
/*
 * Replace RES with a polynomial whose COEF_COUNT lowest coefficients are each
 * drawn with mod_c0_prng() (i.e. from the centered range of CTX).
 *
 * Coefficients are generated in ascending order of degree, one PRNG draw
 * sequence per coefficient.  Aborts if COEF_COUNT does not fit in WORD_MAX.
 */
void poly_prng_mod_c0(fmpz_poly_t res, prng_t prng, void * prng_state,
		      ulong coef_count, mod_centered_0_ctx_t const ctx) {
	fmpz_t coef;
	ulong idx = 0;

	/* The count is later used as a signed FLINT coefficient index. */
	if (coef_count > (ulong) WORD_MAX)
		abort();

	fmpz_init(coef);
	/* Also drops any coefficients of RES at index >= COEF_COUNT. */
	fmpz_poly_realloc(res, coef_count);
	while (idx < coef_count) {
		mod_c0_prng(coef, ctx, prng, prng_state);
		fmpz_poly_set_coeff_fmpz(res, idx, coef);
		idx++;
	}
	fmpz_clear(coef);
}
/*
 * Like poly_prng_mod_c0(), but with prng_getrandom as the randomness source.
 * No PRNG state object is passed (NULL).
 */
void poly_rand_mod_c0(fmpz_poly_t res, ulong coef_count,
		      mod_centered_0_ctx_t const ctx) {
	void * const prng_state = NULL;

	poly_prng_mod_c0(res, prng_getrandom, prng_state, coef_count, ctx);
}
/*
 * Return true iff every stored coefficient c of POLY satisfies
 * |c| <= MAX_COEF.  The zero polynomial trivially satisfies the bound.
 */
bool poly_all_abs_leq(fmpz_poly_struct const * poly, fmpz const * max_coef) {
	ulong const coef_count = fmpz_poly_length(poly);
	fmpz_t abs_value;
	ulong idx = 0;

	fmpz_init(abs_value);
	for (; idx < coef_count; idx++) {
		fmpz_poly_get_coeff_fmpz(abs_value, poly, idx);
		fmpz_abs(abs_value, abs_value);
		if (fmpz_cmp(abs_value, max_coef) > 0)
			break;
	}
	fmpz_clear(abs_value);

	/* Only a scan that ran to the end found no offending coefficient. */
	return idx == coef_count;
}
/* Convenience form of poly_all_abs_leq() taking the bound as a ulong. */
bool poly_all_abs_leq_ui(fmpz_poly_struct const * poly, ulong max_coef) {
	fmpz_t bound;
	bool within_bound;

	fmpz_init_set_ui(bound, max_coef);
	within_bound = poly_all_abs_leq(poly, bound);
	fmpz_clear(bound);

	return within_bound;
}
/***
*** Operations in polynomial rings.
***/
/*
 * Initialise CTX for the ring of polynomials reduced modulo
 * X^DIVISOR_DEGREE + 1 with coefficients reduced via MOD_CTX (copied).
 * Aborts if DIVISOR_DEGREE exceeds WORD_MAX.  Release with
 * poly_ring_ctx_clear().
 */
void poly_ring_ctx_init(poly_ring_ctx_t ctx, mod_centered_0_ctx_t mod_ctx,
			ulong divisor_degree) {
	/* The degree doubles as a coefficient count that must fit a signed word. */
	if (divisor_degree <= (ulong) WORD_MAX) {
		mod_c0_ctx_init_set(ctx->mod_ctx, mod_ctx);
		ctx->divisor_degree = divisor_degree;
		return;
	}
	abort();
}
/* Release the resources owned by a context set up by poly_ring_ctx_init(). */
void poly_ring_ctx_clear(poly_ring_ctx_t ctx) {
	mod_centered_0_ctx_t * const coef_ctx = &ctx->mod_ctx;

	mod_c0_ctx_clear(*coef_ctx);
}
/*
 * Reduce POLY in place into the ring described by CTX.
 *
 * The polynomial is first reduced modulo X^m + 1, with
 * m = ctx->divisor_degree. Each coefficient of degree < m is then brought
 * into the centered range with mod_c0() using ctx->mod_ctx. Coefficients of
 * degree >= m are dropped afterwards.
 *
 * The zero polynomial, and any polynomial of degree < m, is handled without
 * reading coefficients past the end of POLY.
 */
void poly_to_ring(fmpz_poly_t poly, poly_ring_ctx_t const ctx) {
	ulong const divisor_degree = ctx[0].divisor_degree;
	/*
	 * Number of stored coefficients.  fmpz_poly_degree() yields -1 for the
	 * zero polynomial, so this is 0 in that case.  Keeping the degree itself
	 * in an unsigned variable would turn -1 into UWORD_MAX, making the loop
	 * below fetch NULL coefficient pointers and pass them to fmpz_sub().
	 */
	ulong const length = fmpz_poly_degree(poly) + 1;
	fmpz_t new_coef_value;

	fmpz_init(new_coef_value);
	for (ulong coef_idx = 0; coef_idx < divisor_degree; coef_idx++) {
		/*
		 * Polynomial division by X^m+1 can be achieved by substituting
		 * -1 for X^m: the coefficient of X^(i + k*m) contributes to
		 * that of X^i with sign (-1)^k.
		 */
		bool subtract = true;
		ulong higher_coef_idx = coef_idx;

		fmpz_poly_get_coeff_fmpz(new_coef_value, poly, coef_idx);
		/* Stop before stepping by m could wrap the unsigned index. */
		while (UWORD_MAX - divisor_degree >= higher_coef_idx) {
			fmpz const * higher_coef;

			higher_coef_idx += divisor_degree;
			if (higher_coef_idx >= length)
				break;
			higher_coef =
				fmpz_poly_get_coeff_ptr(poly, higher_coef_idx);
			if (subtract)
				fmpz_sub(new_coef_value, new_coef_value,
					 higher_coef);
			else
				fmpz_add(new_coef_value, new_coef_value,
					 higher_coef);
			subtract = !subtract;
		}
		mod_c0(new_coef_value, ctx[0].mod_ctx);
		fmpz_poly_set_coeff_fmpz(poly, coef_idx, new_coef_value);
	}
	/* Drop the coefficients of degree >= m, which were folded in above. */
	if (length > divisor_degree)
		fmpz_poly_realloc(poly, divisor_degree);
	fmpz_clear(new_coef_value);
}
/*
 * Set RES to POLY1 * POLY2 in the ring described by CTX: the product is
 * computed in Z[X] with fmpz_poly_mul() and then reduced in place with
 * poly_to_ring().
 */
void poly_mul_in_ring(fmpz_poly_t res, fmpz_poly_t const poly1,
		      fmpz_poly_t const poly2, poly_ring_ctx_t const ctx) {
	fmpz_poly_mul(res, poly1, poly2);
	poly_to_ring(res, ctx);
}
/*
 * Set RES to POLY1 + POLY2 in the ring described by CTX: the sum is computed
 * in Z[X] with fmpz_poly_add() and then reduced in place with poly_to_ring().
 */
void poly_add_in_ring(fmpz_poly_t res, fmpz_poly_t const poly1,
		      fmpz_poly_t const poly2, poly_ring_ctx_t const ctx) {
	fmpz_poly_add(res, poly1, poly2);
	poly_to_ring(res, ctx);
}
/*
 * Set RES to POLY1 - POLY2 in the ring described by CTX: the difference is
 * computed in Z[X] with fmpz_poly_sub() and then reduced in place with
 * poly_to_ring().
 */
void poly_sub_in_ring(fmpz_poly_t res, fmpz_poly_t const poly1,
		      fmpz_poly_t const poly2, poly_ring_ctx_t const ctx) {
	fmpz_poly_sub(res, poly1, poly2);
	poly_to_ring(res, ctx);
}
/*
 * Replace RES with a random ring element.  It gets ctx->divisor_degree
 * coefficients, each drawn via PRNG from the centered range of ctx->mod_ctx
 * (see poly_prng_mod_c0()).
 */
void poly_prng_in_ring(fmpz_poly_t res, prng_t prng, void * prng_state,
		       poly_ring_ctx_t const ctx) {
	ulong const coef_count = ctx->divisor_degree;

	poly_prng_mod_c0(res, prng, prng_state, coef_count, ctx->mod_ctx);
}
/*
 * Like poly_prng_in_ring(), but with prng_getrandom as the randomness source.
 * No PRNG state object is passed (NULL).
 */
void poly_rand_in_ring(fmpz_poly_t res, poly_ring_ctx_t const ctx) {
	void * const prng_state = NULL;

	poly_prng_in_ring(res, prng_getrandom, prng_state, ctx);
}