aboutsummaryrefslogtreecommitdiff
/*
 * SPDX-License-Identifier: CC0-1.0
 *
 * Copyright (C) 2024 W. Kosior <koszko@koszko.org>
 */

#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

#include <flint/flint.h>
#include <flint/fmpz.h>
#include <flint/fmpz_mod.h>
#include <flint/fmpz_poly.h>

/*
 * Exponent p of the Mersenne number 2^p - 1 used as the test modulus
 * (passed to marsenne_prime_init()).  7 gives the small prime 127; the
 * trailing note records 89, presumably the intended larger test exponent
 * (2^89 - 1).
 */
#define TEST_MERSENNE_EXPONENT 7 /* alternative: 89 */

/*
 * Initialize prime and set it to 2^exponent - 1.
 *
 * No primality check is done: the result is a Mersenne prime only when
 * exponent is a Mersenne exponent (e.g. 7 or 89).  The caller releases
 * prime with fmpz_clear().
 */
void marsenne_prime_init(fmpz_t prime, const ulong exponent) {
	/* Build 2^exponent as 1 << exponent, then subtract one. */
	fmpz_init_set_ui(prime, 1);
	fmpz_mul_2exp(prime, prime, exponent);
	fmpz_sub_ui(prime, prime, 1);
}

/*
 * Initialize poly and fill it with integer coefficients read from file.
 *
 * Expected input: decimal integers separated by single spaces and terminated
 * by a newline, lowest-degree coefficient first; e.g. "1 0 -3\n" is 1 - 3x^2.
 * On malformed input (bad separator, unparsable number, EOF before the
 * newline) an error message is printed to stderr and the program aborts.
 */
void init_read_poly(fmpz_poly_t poly, FILE * file) {
	fmpz_t coef;

	fmpz_poly_init(poly);
	fmpz_init(coef);

	for (slong exponent = 0;; exponent++) {
		int separator_char;

		/*
		 * fmpz_fread() reports failure with a non-positive return
		 * value (0 in practice), so a "< 0" test would never detect
		 * a parse error.
		 */
		if (fmpz_fread(file, coef) <= 0)
			goto error;

		fmpz_poly_set_coeff_fmpz(poly, exponent, coef);

		/* A newline ends the polynomial, a space announces one more. */
		separator_char = getc(file);

		if (separator_char == '\n')
			break;

		if (separator_char != ' ')
			goto error;
	}

	fmpz_clear(coef);
	return;

error:
	fprintf(stderr, "Error reading polynomial.\n");
	abort();
}

/*
 * FLINT's modulo operations produce residues in the range [0, n-1].  Here we
 * provide a facility for reducing big integers to centered residues instead,
 * i.e. to the range [-(n-1)/2, (n-1)/2] (for odd n).
 */

/*
 * Context for centered reduction modulo n:
 *   mod       - the modulus n
 *   range_max - half-width of the centered range, (n-1)/2 for odd n as
 *               computed by mod_c0_ctx_init()
 * Used as a one-element array type so it is passed by reference, FLINT-style.
 */
typedef struct mod_centered_0_ctx {
	fmpz_t mod;
	fmpz_t range_max;
} mod_centered_0_ctx_t[1];

/*
 * Initialize ctx for centered reduction modulo mod.
 *
 * ctx->mod receives a copy of mod and ctx->range_max is set to
 * floor((mod - 1) / 2), so mod_c0() maps values into
 * [-range_max, mod - 1 - range_max]; for odd mod this is the symmetric range
 * [-(mod-1)/2, (mod-1)/2].
 *
 * Aborts if mod is not positive.  Release ctx with mod_c0_ctx_clear().
 */
void mod_c0_ctx_init(mod_centered_0_ctx_t ctx, const fmpz_t mod) {
	struct mod_centered_0_ctx *ctxp = ctx;

	/* fmpz_fdiv_r() in mod_c0() needs a positive divisor for a sane range. */
	if (fmpz_sgn(mod) <= 0)
		abort();

	fmpz_init_set(ctxp->mod, mod);

	fmpz_init(ctxp->range_max);
	fmpz_sub_ui(ctxp->range_max, ctxp->mod, 1);
	/*
	 * Halve with a floor shift.  fmpz_divexact_ui() was wrong here: exact
	 * division requires mod - 1 to be even, which fails for even mod.
	 */
	fmpz_fdiv_q_2exp(ctxp->range_max, ctxp->range_max, 1);
}

/* Release the integers owned by ctx; counterpart of mod_c0_ctx_init(). */
void mod_c0_ctx_clear(mod_centered_0_ctx_t ctx) {
	struct mod_centered_0_ctx *const c = ctx;

	fmpz_clear(c->range_max);
	fmpz_clear(c->mod);
}

/*
 * Reduce value in place to its centered residue modulo ctx->mod.
 *
 * Computes ((value + r) mod n) - r with r = ctx->range_max, using a
 * floor-based remainder, so for positive n the result lies in [-r, n-1-r].
 */
void mod_c0(fmpz_t value, mod_centered_0_ctx_t ctx) {
	const struct mod_centered_0_ctx *c = ctx;

	fmpz_add(value, value, c->range_max);	/* shift window start to 0 */
	fmpz_fdiv_r(value, value, c->mod);	/* floor remainder */
	fmpz_sub(value, value, c->range_max);	/* shift window back */
}

/*
 * Initialize dst_ctx as an independent copy of src_ctx.  Both contexts must
 * later be released separately with mod_c0_ctx_clear().
 */
void mod_c0_ctx_init_set(mod_centered_0_ctx_t dst_ctx,
			 mod_centered_0_ctx_t src_ctx) {
	struct mod_centered_0_ctx *dst = dst_ctx;
	const struct mod_centered_0_ctx *src = src_ctx;

	fmpz_init_set(dst->mod, src->mod);
	fmpz_init_set(dst->range_max, src->range_max);
}

/*
 * Here we provide a facility for performing operations in polynomial rings
 * modulo X^m+1 whose coefficients are integers modulo n (a field when n is
 * prime), kept in the centered range [-(n-1)/2, (n-1)/2].
 */

/*
 * Context for the ring of polynomials modulo X^m + 1 with centered
 * coefficients modulo n:
 *   mod_ctx        - centered reduction context for the coefficient modulus n
 *   divisor_degree - m, the degree of the divisor X^m + 1
 */
typedef struct poly_ring_ctx {
	mod_centered_0_ctx_t mod_ctx;
	slong divisor_degree;
} poly_ring_ctx_t[1];

/*
 * Initialize ctx for polynomials modulo X^divisor_degree + 1 with
 * coefficients reduced by mod_ctx.  mod_ctx is copied, so the caller keeps
 * ownership of it.  Aborts if divisor_degree is negative.  Release ctx with
 * poly_ring_ctx_clear().
 */
void poly_ring_ctx_init(poly_ring_ctx_t ctx, mod_centered_0_ctx_t mod_ctx,
			slong divisor_degree) {
	struct poly_ring_ctx *const ring = ctx;

	if (divisor_degree >= 0) {
		mod_c0_ctx_init_set(ring->mod_ctx, mod_ctx);
		ring->divisor_degree = divisor_degree;
		return;
	}

	abort();
}

/* Release resources held by ctx; counterpart of poly_ring_ctx_init(). */
void poly_ring_ctx_clear(poly_ring_ctx_t ctx) {
	struct poly_ring_ctx *const ring = ctx;

	mod_c0_ctx_clear(ring->mod_ctx);
}

/*
 * Reduce poly in place so it becomes a member of the ring designated by ctx.
 *
 * With m = ctx->divisor_degree, X^m is congruent to -1 modulo X^m + 1, so
 * the coefficient of X^(i + k*m) is folded into X^i with sign (-1)^k.  Each
 * of the m low coefficients is then centered-reduced with mod_c0(), and
 * coefficients of degree >= m are dropped.
 */
void poly_to_ring(fmpz_poly_t poly, poly_ring_ctx_t ctx) {
	const slong m = ctx->divisor_degree;
	const slong deg = fmpz_poly_degree(poly);
	fmpz_t acc;

	fmpz_init(acc);

	for (slong i = 0; i < m; i++) {
		/* The first folded term (k = 1) is subtracted. */
		bool negate = true;

		fmpz_poly_get_coeff_fmpz(acc, poly, i);

		/*
		 * Only coefficients at indices < m are written in this loop,
		 * so the higher ones read here are still the original values.
		 */
		for (slong j = i + m; j <= deg; j += m) {
			const fmpz *term = fmpz_poly_get_coeff_ptr(poly, j);

			if (negate)
				fmpz_sub(acc, acc, term);
			else
				fmpz_add(acc, acc, term);

			negate = !negate;
		}

		mod_c0(acc, ctx->mod_ctx);
		fmpz_poly_set_coeff_fmpz(poly, i, acc);
	}

	/* Discard the now-folded terms of degree >= m. */
	if (deg >= m)
		fmpz_poly_realloc(poly, m);

	fmpz_clear(acc);
}

/*
 * Set res to poly1 * poly2, computed over the integers and then reduced into
 * the ring described by ctx with poly_to_ring().
 */
void poly_mul_in_ring(fmpz_poly_t res,
		      fmpz_poly_t poly1,
		      fmpz_poly_t poly2,
		      poly_ring_ctx_t ctx)
{
	/* Full product first; reduction mod X^m+1 and n happens afterwards. */
	fmpz_poly_mul(res, poly1, poly2);
	poly_to_ring(res, ctx);
}

/*
 * Set res to poly1 + poly2, computed over the integers and then reduced into
 * the ring described by ctx with poly_to_ring().
 */
void poly_add_in_ring(fmpz_poly_t res,
		      fmpz_poly_t poly1,
		      fmpz_poly_t poly2,
		      poly_ring_ctx_t ctx)
{
	/* Plain sum first; poly_to_ring() brings it back into the ring. */
	fmpz_poly_add(res, poly1, poly2);
	poly_to_ring(res, ctx);
}

/*
 * Set res to poly1 - poly2, computed over the integers and then reduced into
 * the ring described by ctx with poly_to_ring().
 */
void poly_sub_in_ring(fmpz_poly_t res,
		      fmpz_poly_t poly1,
		      fmpz_poly_t poly2,
		      poly_ring_ctx_t ctx)
{
	/* Plain difference first; poly_to_ring() brings it back into the ring. */
	fmpz_poly_sub(res, poly1, poly2);
	poly_to_ring(res, ctx);
}

/* Print a polynomial in the variable x followed by a newline. */
static void print_poly_line(const fmpz_poly_t poly) {
	fmpz_poly_print_pretty(poly, "x");
	putchar('\n');
}

/*
 * Experiment 1: add two fixed integers, reduce the sum with mod_c0() and
 * print the operands, the centered range and the result.
 */
static void experiment_mod_addition(mod_centered_0_ctx_t mod_ctx) {
	fmpz_t num1, num2, num_sum;

	fmpz_init_set_ui(num1, 55);
	fmpz_init_set_ui(num2, 31);
	fmpz_init(num_sum);

	fmpz_fprint(stdout, num1);
	printf(" + ");
	fmpz_fprint(stdout, num2);
	printf(" mod [-");
	fmpz_fprint(stdout, mod_ctx[0].range_max);
	putchar(',');
	fmpz_fprint(stdout, mod_ctx[0].range_max);
	printf("] = ");

	fmpz_add(num_sum, num1, num2);
	mod_c0(num_sum, mod_ctx);
	fmpz_fprint(stdout, num_sum);
	putchar('\n');

	fmpz_clear(num1);
	fmpz_clear(num2);
	fmpz_clear(num_sum);
}

/*
 * Experiment 2: read two polynomials and a divisor degree m from stdin, then
 * print their plain product and sum, followed by their product and sum in
 * the ring modulo X^m+1.  Aborts on malformed input.
 */
static void experiment_poly_ring(mod_centered_0_ctx_t mod_ctx) {
	fmpz_poly_t poly1, poly2, poly_computed;
	slong divisor_degree = 0;
	poly_ring_ctx_t poly_ring_ctx;

	/*
	 * Prompts are flushed before each read so they appear even when
	 * stdout is not line-buffered (e.g. redirected to a pipe).
	 */
	printf("Give first polynomial for the experiment:\n");
	fflush(stdout);
	init_read_poly(poly1, stdin);

	printf("Read polynomial: ");
	print_poly_line(poly1);

	printf("Give second polynomial for the experiment:\n");
	fflush(stdout);
	init_read_poly(poly2, stdin);

	printf("Read polynomial: ");
	print_poly_line(poly2);

	fmpz_poly_init(poly_computed);

	printf("Normal product of polynomials:\n");
	fmpz_poly_mul(poly_computed, poly1, poly2);
	print_poly_line(poly_computed);

	printf("Normal sum of polynomials:\n");
	fmpz_poly_add(poly_computed, poly1, poly2);
	print_poly_line(poly_computed);

	printf("Give the degree m of X^m+1 polynomial to be used as ");
	printf("divisor in the ring:\n");
	fflush(stdout);
	if (flint_scanf("%wd", &divisor_degree) < 1 ||
	    divisor_degree < 1) {
		fprintf(stderr, "Bad divisor.\n");
		abort();
	}
	poly_ring_ctx_init(poly_ring_ctx, mod_ctx, divisor_degree);

	printf("Product of polynomials in the ring:\n");
	poly_mul_in_ring(poly_computed, poly1, poly2, poly_ring_ctx);
	print_poly_line(poly_computed);

	printf("Sum of polynomials in the ring:\n");
	poly_add_in_ring(poly_computed, poly1, poly2, poly_ring_ctx);
	print_poly_line(poly_computed);

	fmpz_poly_clear(poly1);
	fmpz_poly_clear(poly2);
	fmpz_poly_clear(poly_computed);

	poly_ring_ctx_clear(poly_ring_ctx);
}

/*
 * Program entry point.  Command-line arguments are not used, so the
 * standard no-argument form of main is used.
 */
int main(void) {
	fmpz_t prime; /* modulus for all modulo operations */
	mod_centered_0_ctx_t mod_ctx;

	/*
	 * Mersenne primes are used just for testing.  The cryptographic
	 * algorithm will use different ones.
	 */
	marsenne_prime_init(prime, TEST_MERSENNE_EXPONENT);

	printf("Prime used for modulo operations: ");
	fmpz_fprint(stdout, prime);
	putchar('\n');

	mod_c0_ctx_init(mod_ctx, prime);

	experiment_mod_addition(mod_ctx);
	experiment_poly_ring(mod_ctx);

	mod_c0_ctx_clear(mod_ctx);
	fmpz_clear(prime);

	return EXIT_SUCCESS;
}