Skip to content

Commit

Permalink
unique names for poly
Browse files Browse the repository at this point in the history
  • Loading branch information
jakemas committed Dec 24, 2024
1 parent ae01f07 commit f96096f
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 73 deletions.
77 changes: 44 additions & 33 deletions crypto/dilithium/pqcrystals_dilithium_ref_common/poly.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
*
* Arguments: - poly *a: pointer to input/output polynomial
**************************************************/
void ml_dsa_poly_reduce(poly *a) {
void ml_dsa_poly_reduce(ml_dsa_poly *a) {
unsigned int i;
for(i = 0; i < ML_DSA_N; ++i) {
a->coeffs[i] = ml_dsa_reduce32(a->coeffs[i]);
Expand All @@ -29,7 +29,7 @@ void ml_dsa_poly_reduce(poly *a) {
*
* Arguments: - poly *a: pointer to input/output polynomial
**************************************************/
void ml_dsa_poly_caddq(poly *a) {
void ml_dsa_poly_caddq(ml_dsa_poly *a) {
unsigned int i;
for(i = 0; i < ML_DSA_N; ++i) {
a->coeffs[i] = ml_dsa_caddq(a->coeffs[i]);
Expand All @@ -45,7 +45,7 @@ void ml_dsa_poly_caddq(poly *a) {
* - const poly *a: pointer to first summand
* - const poly *b: pointer to second summand
**************************************************/
void ml_dsa_poly_add(poly *c, const poly *a, const poly *b) {
void ml_dsa_poly_add(ml_dsa_poly *c, const ml_dsa_poly *a, const ml_dsa_poly *b) {
unsigned int i;
for(i = 0; i < ML_DSA_N; ++i) {
c->coeffs[i] = a->coeffs[i] + b->coeffs[i];
Expand All @@ -63,7 +63,7 @@ void ml_dsa_poly_add(poly *c, const poly *a, const poly *b) {
* - const poly *b: pointer to second input polynomial to be
* subtraced from first input polynomial
**************************************************/
void ml_dsa_poly_sub(poly *c, const poly *a, const poly *b) {
void ml_dsa_poly_sub(ml_dsa_poly *c, const ml_dsa_poly *a, const ml_dsa_poly *b) {
unsigned int i;
for(i = 0; i < ML_DSA_N; ++i) {
c->coeffs[i] = a->coeffs[i] - b->coeffs[i];
Expand All @@ -78,7 +78,7 @@ void ml_dsa_poly_sub(poly *c, const poly *a, const poly *b) {
*
* Arguments: - poly *a: pointer to input/output polynomial
**************************************************/
void ml_dsa_poly_shiftl(poly *a) {
void ml_dsa_poly_shiftl(ml_dsa_poly *a) {
unsigned int i;
for(i = 0; i < ML_DSA_N; ++i) {
a->coeffs[i] <<= ML_DSA_D;
Expand All @@ -93,7 +93,7 @@ void ml_dsa_poly_shiftl(poly *a) {
*
* Arguments: - poly *a: pointer to input/output polynomial
**************************************************/
void ml_dsa_poly_ntt(poly *a) {
void ml_dsa_poly_ntt(ml_dsa_poly *a) {
ml_dsa_ntt(a->coeffs);
}

Expand All @@ -106,7 +106,7 @@ void ml_dsa_poly_ntt(poly *a) {
*
* Arguments: - poly *a: pointer to input/output polynomial
**************************************************/
void ml_dsa_poly_invntt_tomont(poly *a) {
void ml_dsa_poly_invntt_tomont(ml_dsa_poly *a) {
ml_dsa_invntt_tomont(a->coeffs);
}

Expand All @@ -121,7 +121,9 @@ void ml_dsa_poly_invntt_tomont(poly *a) {
* - const poly *a: pointer to first input polynomial
* - const poly *b: pointer to second input polynomial
**************************************************/
void ml_dsa_poly_pointwise_montgomery(poly *c, const poly *a, const poly *b) {
void ml_dsa_poly_pointwise_montgomery(ml_dsa_poly *c,
const ml_dsa_poly *a,
const ml_dsa_poly *b) {
unsigned int i;
for(i = 0; i < ML_DSA_N; ++i) {
c->coeffs[i] = ml_dsa_fqmul(a->coeffs[i], b->coeffs[i]);
Expand All @@ -140,7 +142,7 @@ void ml_dsa_poly_pointwise_montgomery(poly *c, const poly *a, const poly *b) {
* - poly *a0: pointer to output polynomial with coefficients c0
* - const poly *a: pointer to input polynomial
**************************************************/
void ml_dsa_poly_power2round(poly *a1, poly *a0, const poly *a) {
void ml_dsa_poly_power2round(ml_dsa_poly *a1, ml_dsa_poly *a0, const ml_dsa_poly *a) {
unsigned int i;
for(i = 0; i < ML_DSA_N; ++i) {
a1->coeffs[i] = ml_dsa_power2round(&a0->coeffs[i], a->coeffs[i]);
Expand All @@ -161,7 +163,10 @@ void ml_dsa_poly_power2round(poly *a1, poly *a0, const poly *a) {
* - poly *a0: pointer to output polynomial with coefficients c0
* - const poly *a: pointer to input polynomial
**************************************************/
void ml_dsa_poly_decompose(ml_dsa_params *params, poly *a1, poly *a0, const poly *a) {
void ml_dsa_poly_decompose(ml_dsa_params *params,
ml_dsa_poly *a1,
ml_dsa_poly *a0,
const ml_dsa_poly *a) {
unsigned int i;
for(i = 0; i < ML_DSA_N; ++i) {
a1->coeffs[i] = ml_dsa_decompose(params, &a0->coeffs[i], a->coeffs[i]);
Expand All @@ -182,7 +187,10 @@ void ml_dsa_poly_decompose(ml_dsa_params *params, poly *a1, poly *a0, const poly
*
* Returns number of 1 bits.
**************************************************/
unsigned int ml_dsa_poly_make_hint(ml_dsa_params *params, poly *h, const poly *a0, const poly *a1) {
unsigned int ml_dsa_poly_make_hint(ml_dsa_params *params,
ml_dsa_poly *h,
const ml_dsa_poly *a0,
const ml_dsa_poly *a1) {
unsigned int i, s = 0;
for(i = 0; i < ML_DSA_N; ++i) {
h->coeffs[i] = ml_dsa_make_hint(params, a0->coeffs[i], a1->coeffs[i]);
Expand All @@ -201,7 +209,10 @@ unsigned int ml_dsa_poly_make_hint(ml_dsa_params *params, poly *h, const poly *a
* - const poly *a: pointer to input polynomial
* - const poly *h: pointer to input hint polynomial
**************************************************/
void ml_dsa_poly_use_hint(ml_dsa_params *params, poly *b, const poly *a, const poly *h) {
void ml_dsa_poly_use_hint(ml_dsa_params *params,
ml_dsa_poly *b,
const ml_dsa_poly *a,
const ml_dsa_poly *h) {
unsigned int i;
for(i = 0; i < ML_DSA_N; ++i) {
b->coeffs[i] = ml_dsa_use_hint(params, a->coeffs[i], h->coeffs[i]);
Expand All @@ -219,7 +230,7 @@ void ml_dsa_poly_use_hint(ml_dsa_params *params, poly *b, const poly *a, const p
*
* Returns 0 if norm is strictly smaller than B <= (Q-1)/8 and 1 otherwise.
**************************************************/
int ml_dsa_poly_chknorm(const poly *a, int32_t B) {
int ml_dsa_poly_chknorm(const ml_dsa_poly *a, int32_t B) {
unsigned int i;
int32_t t;

Expand All @@ -243,7 +254,7 @@ int ml_dsa_poly_chknorm(const poly *a, int32_t B) {
}

/*************************************************
* Name: rej_uniform
* Name: ml_dsa_rej_uniform
*
* Description: Sample uniformly random coefficients in [0, Q-1] by
* performing rejection sampling on array of random bytes.
Expand All @@ -256,10 +267,10 @@ int ml_dsa_poly_chknorm(const poly *a, int32_t B) {
* Returns number of sampled coefficients. Can be smaller than len if not enough
* random bytes were given.
**************************************************/
static unsigned int rej_uniform(int32_t *a,
unsigned int len,
const uint8_t *buf,
unsigned int buflen)
static unsigned int ml_dsa_rej_uniform(int32_t *a,
unsigned int len,
const uint8_t *buf,
unsigned int buflen)
{
unsigned int ctr, pos;
uint32_t t;
Expand Down Expand Up @@ -291,7 +302,7 @@ static unsigned int rej_uniform(int32_t *a,
* - uint16_t nonce: 2-byte nonce
**************************************************/
#define POLY_UNIFORM_NBLOCKS ((768 + SHAKE128_RATE - 1)/ SHAKE128_RATE)
void ml_dsa_poly_uniform(poly *a,
void ml_dsa_poly_uniform(ml_dsa_poly *a,
const uint8_t seed[ML_DSA_SEEDBYTES],
uint16_t nonce)
{
Expand All @@ -309,7 +320,7 @@ void ml_dsa_poly_uniform(poly *a,
SHA3_Update(&state, t, 2);
SHAKE_Final(buf, &state, POLY_UNIFORM_NBLOCKS * SHAKE128_BLOCKSIZE);

ctr = rej_uniform(a->coeffs, ML_DSA_N, buf, buflen);
ctr = ml_dsa_rej_uniform(a->coeffs, ML_DSA_N, buf, buflen);

while(ctr < ML_DSA_N) {
off = buflen % 3;
Expand All @@ -318,7 +329,7 @@ void ml_dsa_poly_uniform(poly *a,

SHAKE_Final(buf + off, &state, POLY_UNIFORM_NBLOCKS * SHAKE128_BLOCKSIZE);
buflen = SHAKE128_RATE + off;
ctr += rej_uniform(a->coeffs + ctr, ML_DSA_N - ctr, buf, buflen);
ctr += ml_dsa_rej_uniform(a->coeffs + ctr, ML_DSA_N - ctr, buf, buflen);
}
/* FIPS 204. Section 3.6.3 Destruction of intermediate values. */
OPENSSL_cleanse(buf, sizeof(buf));
Expand Down Expand Up @@ -393,7 +404,7 @@ static unsigned int rej_eta(ml_dsa_params *params,
* - uint16_t nonce: 2-byte nonce
**************************************************/
void ml_dsa_poly_uniform_eta(ml_dsa_params *params,
poly *a,
ml_dsa_poly *a,
const uint8_t seed[ML_DSA_CRHBYTES],
uint16_t nonce)
{
Expand Down Expand Up @@ -436,7 +447,7 @@ void ml_dsa_poly_uniform_eta(ml_dsa_params *params,
**************************************************/
#define POLY_UNIFORM_GAMMA1_NBLOCKS ((ML_DSA_POLYZ_PACKEDBYTES_MAX + SHAKE256_RATE - 1) / SHAKE256_RATE)
void ml_dsa_poly_uniform_gamma1(ml_dsa_params *params,
poly *a,
ml_dsa_poly *a,
const uint8_t seed[ML_DSA_CRHBYTES],
uint16_t nonce)
{
Expand Down Expand Up @@ -469,7 +480,7 @@ void ml_dsa_poly_uniform_gamma1(ml_dsa_params *params,
* - poly *c: pointer to output polynomial
* - const uint8_t mu[]: byte array containing seed of length CTILDEBYTES
**************************************************/
void ml_dsa_poly_challenge(ml_dsa_params *params, poly *c, const uint8_t *seed) {
void ml_dsa_poly_challenge(ml_dsa_params *params, ml_dsa_poly *c, const uint8_t *seed) {
unsigned int i, b, pos;
uint64_t signs;
uint8_t buf[SHAKE256_RATE];
Expand Down Expand Up @@ -518,7 +529,7 @@ void ml_dsa_poly_challenge(ml_dsa_params *params, poly *c, const uint8_t *seed)
* POLYETA_PACKEDBYTES bytes
* - const poly *a: pointer to input polynomial
**************************************************/
void ml_dsa_polyeta_pack(ml_dsa_params *params, uint8_t *r, const poly *a) {
void ml_dsa_polyeta_pack(ml_dsa_params *params, uint8_t *r, const ml_dsa_poly *a) {
unsigned int i;
uint8_t t[8];

Expand Down Expand Up @@ -559,7 +570,7 @@ void ml_dsa_polyeta_pack(ml_dsa_params *params, uint8_t *r, const poly *a) {
* - poly *r: pointer to output polynomial
* - const uint8_t *a: byte array with bit-packed polynomial
**************************************************/
void ml_dsa_polyeta_unpack(ml_dsa_params *params, poly *r, const uint8_t *a) {
void ml_dsa_polyeta_unpack(ml_dsa_params *params, ml_dsa_poly *r, const uint8_t *a) {
unsigned int i;
assert((params->eta == 2) ||
(params->eta == 4));
Expand Down Expand Up @@ -605,7 +616,7 @@ void ml_dsa_polyeta_unpack(ml_dsa_params *params, poly *r, const uint8_t *a) {
* POLYT1_PACKEDBYTES bytes
* - const poly *a: pointer to input polynomial
**************************************************/
void ml_dsa_polyt1_pack(uint8_t *r, const poly *a) {
void ml_dsa_polyt1_pack(uint8_t *r, const ml_dsa_poly *a) {
unsigned int i;

for(i = 0; i < ML_DSA_N/4; ++i) {
Expand All @@ -626,7 +637,7 @@ void ml_dsa_polyt1_pack(uint8_t *r, const poly *a) {
* Arguments: - poly *r: pointer to output polynomial
* - const uint8_t *a: byte array with bit-packed polynomial
**************************************************/
void ml_dsa_polyt1_unpack(poly *r, const uint8_t *a) {
void ml_dsa_polyt1_unpack(ml_dsa_poly *r, const uint8_t *a) {
unsigned int i;

for(i = 0; i < ML_DSA_N/4; ++i) {
Expand All @@ -646,7 +657,7 @@ void ml_dsa_polyt1_unpack(poly *r, const uint8_t *a) {
* POLYT0_PACKEDBYTES bytes
* - const poly *a: pointer to input polynomial
**************************************************/
void ml_dsa_polyt0_pack(uint8_t *r, const poly *a) {
void ml_dsa_polyt0_pack(uint8_t *r, const ml_dsa_poly *a) {
unsigned int i;
uint32_t t[8];

Expand Down Expand Up @@ -691,7 +702,7 @@ void ml_dsa_polyt0_pack(uint8_t *r, const poly *a) {
* Arguments: - poly *r: pointer to output polynomial
* - const uint8_t *a: byte array with bit-packed polynomial
**************************************************/
void ml_dsa_polyt0_unpack(poly *r, const uint8_t *a) {
void ml_dsa_polyt0_unpack(ml_dsa_poly *r, const uint8_t *a) {
unsigned int i;

for(i = 0; i < ML_DSA_N/8; ++i) {
Expand Down Expand Up @@ -753,7 +764,7 @@ void ml_dsa_polyt0_unpack(poly *r, const uint8_t *a) {
* POLYZ_PACKEDBYTES bytes
* - const poly *a: pointer to input polynomial
**************************************************/
void ml_dsa_polyz_pack(ml_dsa_params *params, uint8_t *r, const poly *a) {
void ml_dsa_polyz_pack(ml_dsa_params *params, uint8_t *r, const ml_dsa_poly *a) {
unsigned int i;
uint32_t t[4];

Expand Down Expand Up @@ -806,7 +817,7 @@ void ml_dsa_polyz_pack(ml_dsa_params *params, uint8_t *r, const poly *a) {
* - poly *r: pointer to output polynomial
* - const uint8_t *a: byte array with bit-packed polynomial
**************************************************/
void ml_dsa_polyz_unpack(ml_dsa_params *params, poly *r, const uint8_t *a) {
void ml_dsa_polyz_unpack(ml_dsa_params *params, ml_dsa_poly *r, const uint8_t *a) {
unsigned int i;

assert((params->gamma1 == (1 << 17)) ||
Expand Down Expand Up @@ -869,7 +880,7 @@ void ml_dsa_polyz_unpack(ml_dsa_params *params, poly *r, const uint8_t *a) {
* POLYW1_PACKEDBYTES bytes
* - const poly *a: pointer to input polynomial
**************************************************/
void ml_dsa_polyw1_pack(ml_dsa_params *params, uint8_t *r, const poly *a) {
void ml_dsa_polyw1_pack(ml_dsa_params *params, uint8_t *r, const ml_dsa_poly *a) {
unsigned int i;

if (params->gamma2 == (ML_DSA_Q-1)/88) {
Expand Down
66 changes: 37 additions & 29 deletions crypto/dilithium/pqcrystals_dilithium_ref_common/poly.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,69 +6,77 @@

typedef struct {
int32_t coeffs[ML_DSA_N];
} poly;
} ml_dsa_poly;

void ml_dsa_poly_reduce(poly *a);
void ml_dsa_poly_reduce(ml_dsa_poly *a);

void ml_dsa_poly_caddq(poly *a);
void ml_dsa_poly_caddq(ml_dsa_poly *a);

void ml_dsa_poly_add(poly *c, const poly *a, const poly *b);
void ml_dsa_poly_add(ml_dsa_poly *c, const ml_dsa_poly *a, const ml_dsa_poly *b);

void ml_dsa_poly_sub(poly *c, const poly *a, const poly *b);
void ml_dsa_poly_sub(ml_dsa_poly *c, const ml_dsa_poly *a, const ml_dsa_poly *b);

void ml_dsa_poly_shiftl(poly *a);
void ml_dsa_poly_shiftl(ml_dsa_poly *a);

void ml_dsa_poly_ntt(poly *a);
void ml_dsa_poly_ntt(ml_dsa_poly *a);

void ml_dsa_poly_invntt_tomont(poly *a);
void ml_dsa_poly_invntt_tomont(ml_dsa_poly *a);

void ml_dsa_poly_pointwise_montgomery(poly *c, const poly *a, const poly *b);
void ml_dsa_poly_pointwise_montgomery(ml_dsa_poly *c,
const ml_dsa_poly *a,
const ml_dsa_poly *b);

void ml_dsa_poly_power2round(poly *a1, poly *a0, const poly *a);
void ml_dsa_poly_power2round(ml_dsa_poly *a1, ml_dsa_poly *a0, const ml_dsa_poly *a);

void ml_dsa_poly_decompose(ml_dsa_params *params, poly *a1, poly *a0, const poly *a);
void ml_dsa_poly_decompose(ml_dsa_params *params,
ml_dsa_poly *a1,
ml_dsa_poly *a0,
const ml_dsa_poly *a);

unsigned int ml_dsa_poly_make_hint(ml_dsa_params *params,
poly *h,
const poly *a0,
const poly *a1);
ml_dsa_poly *h,
const ml_dsa_poly *a0,
const ml_dsa_poly *a1);

void ml_dsa_poly_use_hint(ml_dsa_params *params, poly *b, const poly *a, const poly *h);
void ml_dsa_poly_use_hint(ml_dsa_params *params,
ml_dsa_poly *b,
const ml_dsa_poly *a,
const ml_dsa_poly *h);

int ml_dsa_poly_chknorm(const poly *a, int32_t B);
int ml_dsa_poly_chknorm(const ml_dsa_poly *a, int32_t B);

void ml_dsa_poly_uniform(poly *a,
void ml_dsa_poly_uniform(ml_dsa_poly *a,
const uint8_t seed[ML_DSA_SEEDBYTES],
uint16_t nonce);

void ml_dsa_poly_uniform_eta(ml_dsa_params *params,
poly *a,
ml_dsa_poly *a,
const uint8_t seed[ML_DSA_CRHBYTES],
uint16_t nonce);

void ml_dsa_poly_uniform_gamma1(ml_dsa_params *params,
poly *a,
ml_dsa_poly *a,
const uint8_t seed[ML_DSA_CRHBYTES],
uint16_t nonce);

void ml_dsa_poly_challenge(ml_dsa_params *params, poly *c, const uint8_t *seed);
void ml_dsa_poly_challenge(ml_dsa_params *params, ml_dsa_poly *c, const uint8_t *seed);

void ml_dsa_polyeta_pack(ml_dsa_params *params, uint8_t *r, const poly *a);
void ml_dsa_polyeta_pack(ml_dsa_params *params, uint8_t *r, const ml_dsa_poly *a);

void ml_dsa_polyeta_unpack(ml_dsa_params *params, poly *r, const uint8_t *a);
void ml_dsa_polyeta_unpack(ml_dsa_params *params, ml_dsa_poly *r, const uint8_t *a);

void ml_dsa_polyt1_pack(uint8_t *r, const poly *a);
void ml_dsa_polyt1_pack(uint8_t *r, const ml_dsa_poly *a);

void ml_dsa_polyt1_unpack(poly *r, const uint8_t *a);
void ml_dsa_polyt1_unpack(ml_dsa_poly *r, const uint8_t *a);

void ml_dsa_polyt0_pack(uint8_t *r, const poly *a);
void ml_dsa_polyt0_pack(uint8_t *r, const ml_dsa_poly *a);

void ml_dsa_polyt0_unpack(poly *r, const uint8_t *a);
void ml_dsa_polyt0_unpack(ml_dsa_poly *r, const uint8_t *a);

void ml_dsa_polyz_pack(ml_dsa_params *params, uint8_t *r, const poly *a);
void ml_dsa_polyz_pack(ml_dsa_params *params, uint8_t *r, const ml_dsa_poly *a);

void ml_dsa_polyz_unpack(ml_dsa_params *params, poly *r, const uint8_t *a);
void ml_dsa_polyz_unpack(ml_dsa_params *params, ml_dsa_poly *r, const uint8_t *a);

void ml_dsa_polyw1_pack(ml_dsa_params *params, uint8_t *r, const poly *a);
void ml_dsa_polyw1_pack(ml_dsa_params *params, uint8_t *r, const ml_dsa_poly *a);

#endif
Loading

0 comments on commit f96096f

Please sign in to comment.