Implement ring multiplication method

This commit is contained in:
Andreas Steffen 2014-02-26 23:36:09 +01:00
parent bd1c9f1eac
commit bf24960cbe
4 changed files with 246 additions and 114 deletions

View File

@ -111,7 +111,6 @@ ntru_crypto_ntru_encrypt(
uint8_t *mask_trits;
chunk_t seed;
ntru_poly_t *r_poly;
uint16_t *r_indices;
/* check for bad parameters */
@ -230,8 +229,8 @@ ntru_crypto_ntru_encrypt(
seed = chunk_create(tmp_buf, ptr - tmp_buf);
r_poly = ntru_poly_create(hash_algid, seed, params->c_bits,
params->N, 2 * params->dF_r,
params->is_product_form);
params->N, params->q, params->dF_r,
params->dF_r, params->is_product_form);
if (!r_poly)
{
result = NTRU_MGF1_FAIL;
@ -249,21 +248,7 @@ ntru_crypto_ntru_encrypt(
params->q_bits, ringel_buf);
/* form R = h * r */
r_indices = r_poly->get_indices(r_poly);
if (params->is_product_form)
{
ntru_ring_mult_product_indices(ringel_buf, (uint16_t)dr1,
(uint16_t)dr2, (uint16_t)dr3,
r_indices, params->N, params->q,
scratch_buf, ringel_buf);
}
else
{
ntru_ring_mult_indices(ringel_buf, (uint16_t)dr, (uint16_t)dr,
r_indices, params->N, params->q,
scratch_buf, ringel_buf);
}
r_poly->ring_mult(r_poly, ringel_buf, ringel_buf);
r_poly->destroy(r_poly);
/* form R mod 4 */
@ -459,7 +444,6 @@ ntru_crypto_ntru_decrypt(
uint8_t *mask_trits;
chunk_t seed;
ntru_poly_t *r_poly;
uint16_t *r_indices;
/* check for bad parameters */
if (!privkey_blob || !ct || !pt_len)
@ -582,29 +566,41 @@ ntru_crypto_ntru_decrypt(
(uint16_t)dF_r2, (uint16_t)dF_r3,
i_buf, params->N, params->q,
scratch_buf, ringel_buf1);
for (i = 0; i < cmprime_len; i++) {
ringel_buf1[i] = (ringel_buf2[i] + 3 * ringel_buf1[i]) & mod_q_mask;
if (ringel_buf1[i] >= (params->q >> 1))
ringel_buf1[i] = ringel_buf1[i] - q_mod_p;
Mtrin_buf[i] = (uint8_t)(ringel_buf1[i] % 3);
if (Mtrin_buf[i] == 1)
++m1;
else if (Mtrin_buf[i] == 2)
--m1;
}
}
for (i = 0; i < cmprime_len; i++)
{
ringel_buf1[i] = (ringel_buf2[i] + 3 * ringel_buf1[i]) & mod_q_mask;
if (ringel_buf1[i] >= (params->q >> 1))
{
ringel_buf1[i] = ringel_buf1[i] - q_mod_p;
}
Mtrin_buf[i] = (uint8_t)(ringel_buf1[i] % 3);
if (Mtrin_buf[i] == 1)
{
++m1;
}
else if (Mtrin_buf[i] == 2)
{
--m1;
}
}
}
else
{
ntru_ring_mult_indices(ringel_buf2, (uint16_t)dF_r, (uint16_t)dF_r,
i_buf, params->N, params->q,
scratch_buf, ringel_buf1);
for (i = 0; i < cmprime_len; i++) {
ringel_buf1[i] = (ringel_buf2[i] + 3 * ringel_buf1[i]) & mod_q_mask;
if (ringel_buf1[i] >= (params->q >> 1))
ringel_buf1[i] = ringel_buf1[i] - q_mod_p;
Mtrin_buf[i] = (uint8_t)(ringel_buf1[i] % 3);
}
}
for (i = 0; i < cmprime_len; i++)
{
ringel_buf1[i] = (ringel_buf2[i] + 3 * ringel_buf1[i]) & mod_q_mask;
if (ringel_buf1[i] >= (params->q >> 1))
{
ringel_buf1[i] = ringel_buf1[i] - q_mod_p;
}
Mtrin_buf[i] = (uint8_t)(ringel_buf1[i] % 3);
}
}
/* check that the candidate message representative meets minimum weight
* requirements
@ -712,8 +708,8 @@ ntru_crypto_ntru_decrypt(
seed = chunk_create(tmp_buf, ptr - tmp_buf);
r_poly = ntru_poly_create(hash_algid, seed, params->c_bits,
params->N, 2 * params->dF_r,
params->is_product_form);
params->N, params->q, params->dF_r,
params->dF_r, params->is_product_form);
if (!r_poly)
{
result = NTRU_MGF1_FAIL;
@ -733,20 +729,7 @@ ntru_crypto_ntru_decrypt(
}
/* form cR' = h * cr */
r_indices = r_poly->get_indices(r_poly);
if (params->is_product_form)
{
ntru_ring_mult_product_indices(ringel_buf1, (uint16_t)dF_r1,
(uint16_t)dF_r2, (uint16_t)dF_r3,
r_indices, params->N, params->q,
scratch_buf, ringel_buf1);
}
else
{
ntru_ring_mult_indices(ringel_buf1, (uint16_t)dF_r, (uint16_t)dF_r,
r_indices, params->N, params->q,
scratch_buf, ringel_buf1);
}
r_poly->ring_mult(r_poly, ringel_buf1, ringel_buf1);
r_poly->destroy(r_poly);
/* compare cR' to cR */
@ -857,7 +840,7 @@ ntru_crypto_ntru_encrypt_keygen(
uint32_t result = NTRU_OK;
ntru_poly_t *F_poly = NULL;
ntru_poly_t *g_poly = NULL;
uint16_t *F_indices, *g_indices;
uint16_t *F_indices;
/* get a pointer to the parameter-set parameters */
@ -959,8 +942,8 @@ ntru_crypto_ntru_encrypt_keygen(
seed = chunk_create(tmp_buf, seed_len);
F_poly = ntru_poly_create(hash_algid, seed, params->c_bits,
params->N, 2 * params->dF_r,
params->is_product_form);
params->N, params->q, params->dF_r,
params->dF_r, params->is_product_form);
if (!F_poly)
{
result = NTRU_MGF1_FAIL;
@ -1055,7 +1038,8 @@ ntru_crypto_ntru_encrypt_keygen(
seed = chunk_create(tmp_buf, seed_len);
g_poly = ntru_poly_create(hash_algid, seed, params->c_bits,
params->N, 2*params->dg + 1, FALSE);
params->N, params->q, params->dg + 1,
params->dg, FALSE);
if (!g_poly)
{
result = NTRU_MGF1_FAIL;
@ -1067,10 +1051,7 @@ ntru_crypto_ntru_encrypt_keygen(
uint16_t i;
/* compute h = p * (f^-1 * g) mod q */
g_indices = g_poly->get_indices(g_poly);
ntru_ring_mult_indices(ringel_buf2, params->dg + 1, params->dg,
g_indices, params->N, params->q, scratch_buf,
ringel_buf2);
g_poly->ring_mult(g_poly, ringel_buf2, ringel_buf2);
g_poly->destroy(g_poly);
for (i = 0; i < params->N; i++)

View File

@ -22,6 +22,15 @@
#include <utils/test.h>
typedef struct private_ntru_poly_t private_ntru_poly_t;
typedef struct indices_len_t indices_len_t;
/**
* Stores number of +1 and -1 coefficients
*/
struct indices_len_t {
int p;
int m;
};
/**
* Private data of an ntru_poly_t object.
@ -33,22 +42,44 @@ struct private_ntru_poly_t {
*/
ntru_poly_t public;
/**
* Ring dimension equal to the number of polynomial coefficients
*/
uint16_t N;
/**
* Large modulus
*/
uint16_t q;
/**
* Array containing the indices of the non-zero coefficients
*/
uint16_t *indices;
/**
* Number of non-zero coefficients
* Number of sparse polynomials
*/
uint32_t indices_len;
int num_polynomials;
/**
* Number of nonzero coefficients for up to 3 sparse polynomials
*/
indices_len_t indices_len[3];
};
METHOD(ntru_poly_t, get_size, size_t,
private_ntru_poly_t *this)
{
return this->indices_len;
int n;
size_t size = 0;
for (n = 0; n < this->num_polynomials; n++)
{
size += this->indices_len[n].p + this->indices_len[n].m;
}
return size;
}
METHOD(ntru_poly_t, get_indices, uint16_t*,
@ -56,11 +87,113 @@ METHOD(ntru_poly_t, get_indices, uint16_t*,
{
return this->indices;
}
/**
* Multiplication of polynomial a with a sparse polynomial b given by
* the indices of its +1 and -1 coefficients results in polynomial c.
* This is a convolution operation
*/
static void ring_mult_indices(uint16_t *a, indices_len_t len, uint16_t *indices,
uint16_t N, uint16_t mod_q_mask, uint16_t *c)
{
uint16_t *t;
int i, j, k;
/* allocate and initialize temporary array t */
t = malloc(N * sizeof(uint16_t));
for (k = 0; k < N; k++)
{
t[k] = 0;
}
/* t[(i+k)%N] = sum i=0 through N-1 of a[i], for b[k] = -1 */
for (j = len.p; j < len.p + len.m; j++)
{
k = indices[j];
for (i = 0; k < N; ++i, ++k)
{
t[k] += a[i];
}
for (k = 0; i < N; ++i, ++k)
{
t[k] += a[i];
}
}
/* t[(i+k)%N] = -(sum i=0 through N-1 of a[i] for b[k] = -1) */
for (k = 0; k < N; k++)
{
t[k] = -t[k];
}
/* t[(i+k)%N] += sum i=0 through N-1 of a[i] for b[k] = +1 */
for (j = 0; j < len.p; j++)
{
k = indices[j];
for (i = 0; k < N; ++i, ++k)
{
t[k] += a[i];
}
for (k = 0; i < N; ++i, ++k)
{
t[k] += a[i];
}
}
/* c = (a * b) mod q */
for (k = 0; k < N; k++)
{
c[k] = t[k] & mod_q_mask;
}
/* cleanup */
free(t);
}
METHOD(ntru_poly_t, ring_mult, void,
private_ntru_poly_t *this, uint16_t *a, uint16_t *c)
{
uint16_t *bi = this->indices, mod_q_mask = this->q - 1;
if (this->num_polynomials == 1)
{
ring_mult_indices(a, this->indices_len[0], bi, this->N, mod_q_mask, c);
}
else
{
uint16_t *t1, *t2;
int i;
/* allocate temporary arrays */
t1 = malloc(this->N * sizeof(uint16_t));
t2 = malloc(this->N * sizeof(uint16_t));
/* t1 = a * b1 */
ring_mult_indices(a, this->indices_len[0], bi, this->N, mod_q_mask, t1);
/* t1 = (a * b1) * b2 */
bi += this->indices_len[0].p + this->indices_len[0].m;
ring_mult_indices(t1, this->indices_len[1], bi, this->N, mod_q_mask, t1);
/* t2 = a * b3 */
bi += this->indices_len[1].p + this->indices_len[1].m;
ring_mult_indices(a, this->indices_len[2], bi, this->N, mod_q_mask, t2);
/* c = (a * b1 * b2) + (a * b3) */
for (i = 0; i < this->N; i++)
{
c[i] = (t1[i] + t2[i]) & mod_q_mask;
}
/* cleanup */
free(t1);
free(t2);
}
}
METHOD(ntru_poly_t, destroy, void,
private_ntru_poly_t *this)
{
memwipe(this->indices, this->indices_len);
memwipe(this->indices, get_size(this));
free(this->indices);
free(this);
}
@ -69,14 +202,15 @@ METHOD(ntru_poly_t, destroy, void,
* Described in header.
*/
ntru_poly_t *ntru_poly_create(hash_algorithm_t alg, chunk_t seed,
uint8_t c_bits, uint16_t poly_len,
uint32_t indices_count, bool is_product_form)
uint8_t c_bits, uint16_t N, uint16_t q,
uint32_t indices_len_p, uint32_t indices_len_m,
bool is_product_form)
{
private_ntru_poly_t *this;
size_t hash_len, octet_count = 0, i, num_polys, num_indices[3], indices_len;
size_t hash_len, octet_count = 0, i;
uint8_t octets[HASH_SIZE_SHA512], *used, num_left = 0, num_needed;
uint16_t index, limit, left = 0;
int poly_i = 0, index_i = 0;
int n, num_indices, index_i = 0;
ntru_mgf1_t *mgf1;
DBG2(DBG_LIB, "MGF1 is seeded with %u bytes", seed.len);
@ -87,40 +221,47 @@ ntru_poly_t *ntru_poly_create(hash_algorithm_t alg, chunk_t seed,
}
i = hash_len = mgf1->get_hash_size(mgf1);
if (is_product_form)
{
num_polys = 3;
num_indices[0] = 0xff & indices_count;
num_indices[1] = 0xff & (indices_count >> 8);
num_indices[2] = 0xff & (indices_count >> 16);
indices_len = num_indices[0] + num_indices[1] + num_indices[2];
}
else
{
num_polys = 1;
num_indices[0] = indices_count;
indices_len = indices_count;
}
used = malloc(poly_len);
limit = poly_len * ((1 << c_bits) / poly_len);
INIT(this,
.public = {
.get_size = _get_size,
.get_indices = _get_indices,
.ring_mult = _ring_mult,
.destroy = _destroy,
},
.indices_len = indices_len,
.indices = malloc(indices_len * sizeof(uint16_t)),
.N = N,
.q = q,
);
/* generate indices for all polynomials */
while (poly_i < num_polys)
if (is_product_form)
{
memset(used, 0, poly_len);
this->num_polynomials = 3;
for (n = 0; n < 3; n++)
{
this->indices_len[n].p = 0xff & indices_len_p;
this->indices_len[n].m = 0xff & indices_len_m;
indices_len_p >>= 8;
indices_len_m >>= 8;
}
}
else
{
this->num_polynomials = 1;
this->indices_len[0].p = indices_len_p;
this->indices_len[0].m = indices_len_m;
}
this->indices = malloc(sizeof(uint16_t) * get_size(this)),
used = malloc(N);
limit = N * ((1 << c_bits) / N);
/* generate indices for all polynomials */
for (n = 0; n < this->num_polynomials; n++)
{
memset(used, 0, N);
num_indices = this->indices_len[n].p + this->indices_len[n].m;
/* generate indices for a single polynomial */
while (num_indices[poly_i])
while (num_indices)
{
/* generate a random candidate index with a size of c_bits */
do
@ -167,19 +308,18 @@ ntru_poly_t *ntru_poly_create(hash_algorithm_t alg, chunk_t seed,
while (index >= limit);
/* form index and check if unique */
index %= poly_len;
index %= N;
if (!used[index])
{
used[index] = 1;
this->indices[index_i++] = index;
num_indices[poly_i]--;
num_indices--;
}
}
poly_i++;
}
DBG2(DBG_LIB, "MGF1 generates %u octets to derive %u indices",
octet_count, this->indices_len);
octet_count, get_size(this));
mgf1->destroy(mgf1);
free(used);

View File

@ -42,6 +42,11 @@ struct ntru_poly_t {
*/
uint16_t* (*get_indices)(ntru_poly_t *this);
/**
* @return array containing the indices of the non-zero coefficients
*/
void (*ring_mult)(ntru_poly_t *this, uint16_t *a, uint16_t *c);
/**
* Destroy ntru_poly_t object
*/
@ -53,14 +58,17 @@ struct ntru_poly_t {
*
* @param alg hash algorithm to be used by MGF1
* @param seed seed used by MGF1 to generate trits from
* @param poly_len size of the trits polynomial
* @param N ring dimension, number of polynomial coefficients
* @param q large modulus
* @param c_bits number of bits for candidate index
* @param indices_count number of non-zero indices
* @param indices_len_p number of indices for +1 coefficients
* @param indices_len_m number of indices for -1 coefficients
* @param is_product_form generate multiple polynomials
*/
ntru_poly_t *ntru_poly_create(hash_algorithm_t alg, chunk_t seed,
uint8_t c_bits, uint16_t poly_len,
uint32_t indices_count, bool is_product_form);
uint8_t c_bits, uint16_t N, uint16_t q,
uint32_t indices_len_p, uint32_t indices_len_m,
bool is_product_form);
#endif /** NTRU_POLY_H_ @}*/

View File

@ -33,8 +33,8 @@ IMPORT_FUNCTION_FOR_TESTS(ntru, ntru_trits_create, ntru_trits_t*,
IMPORT_FUNCTION_FOR_TESTS(ntru, ntru_poly_create, ntru_poly_t*,
hash_algorithm_t alg, chunk_t seed, uint8_t c_bits,
uint16_t poly_len, uint32_t indices_count,
bool is_product_form)
uint16_t N, uint16_t q, uint32_t indices_len_p,
uint32_t indices_len_m, bool is_product_form)
/**
* NTRU parameter sets to test
@ -302,10 +302,11 @@ END_TEST
typedef struct {
uint8_t c_bits;
uint16_t poly_len;
uint16_t N;
uint16_t q;
bool is_product_form;
uint32_t indices_count;
uint32_t indices_len;
uint32_t indices_size;
uint16_t *indices;
} poly_test_t;
@ -427,10 +428,10 @@ mgf1_test_t mgf1_tests[] = {
0, 1, 1, 2, 0, 2, 2, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1,
0, 1, 2, 0, 1, 1, 0, 1, 2, 0, 0, 1, 2, 2, 0, 0, 2, 1, 2),
{
{ 9, 439, TRUE, 2*(9 + (8 << 8) + (5 << 16)),
{ 9, 439, 2048, TRUE, 9 + (8 << 8) + (5 << 16),
countof(indices_ees439ep1), indices_ees439ep1
},
{ 11, 613, FALSE, 2*55,
{ 11, 613, 2048, FALSE, 55,
countof(indices_ees613ep1), indices_ees613ep1
}
}
@ -514,10 +515,10 @@ mgf1_test_t mgf1_tests[] = {
1, 0, 1, 0, 2, 2, 1, 0, 2, 2, 2, 2, 2, 1, 0, 2, 2, 2, 1, 2,
0, 2, 0, 0, 0, 0, 0, 1, 2, 0, 1, 0, 1),
{
{ 13, 743, TRUE, 2*(11 + (11 << 8) + (15 << 16)),
{ 13, 743, 2048, TRUE, 11 + (11 << 8) + (15 << 16),
countof(indices_ees743ep1), indices_ees743ep1
},
{ 12, 1171, FALSE, 2*106,
{ 12, 1171, 2048, FALSE, 106,
countof(indices_ees1171ep1), indices_ees1171ep1
}
}
@ -632,19 +633,21 @@ START_TEST(test_ntru_poly)
seed.len = mgf1_tests[_i].seed_len;
p = &mgf1_tests[_i].poly_test[0];
poly = ntru_poly_create(HASH_UNKNOWN, seed, p->c_bits, p->poly_len,
p->indices_count, p->is_product_form);
poly = ntru_poly_create(HASH_UNKNOWN, seed, p->c_bits, p->N, p->q,
p->indices_len, p->indices_len,
p->is_product_form);
ck_assert(poly == NULL);
for (n = 0; n < 2; n++)
{
p = &mgf1_tests[_i].poly_test[n];
poly = ntru_poly_create(mgf1_tests[_i].alg, seed, p->c_bits, p->poly_len,
p->indices_count, p->is_product_form);
ck_assert(poly != NULL && poly->get_size(poly) == p->indices_len);
poly = ntru_poly_create(mgf1_tests[_i].alg, seed, p->c_bits, p->N, p->q,
p->indices_len, p->indices_len,
p->is_product_form);
ck_assert(poly != NULL && poly->get_size(poly) == p->indices_size);
indices = poly->get_indices(poly);
for (j = 0; j < p->indices_len; j++)
for (j = 0; j < p->indices_size; j++)
{
ck_assert(indices[j] == p->indices[j]);
}