453 lines
8.5 KiB
C
453 lines
8.5 KiB
C
/*
|
|
* Copyright (C) 2014 Andreas Steffen
|
|
* HSR Hochschule fuer Technik Rapperswil
|
|
*
|
|
* Copyright (C) 2009-2013 Security Innovation
|
|
*
|
|
* This program is free software; you can redistribute it and/or modify it
|
|
* under the terms of the GNU General Public License as published by the
|
|
* Free Software Foundation; either version 2 of the License, or (at your
|
|
* option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
|
|
*
|
|
* This program is distributed in the hope that it will be useful, but
|
|
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
|
|
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
|
* for more details.
|
|
*/
|
|
|
|
#include <stdlib.h>
|
|
#include <string.h>
|
|
|
|
#include "ntru_convert.h"
|
|
|
|
/**
 * 3-bit to 2-trit conversion tables, indexed by a 3-bit value v (0..7):
 * bits_2_trit1[v] is the first trit and bits_2_trit2[v] the second, so that
 * v == 3 * bits_2_trit1[v] + bits_2_trit2[v].  The trit value 2 represents -1.
 */
static uint8_t const bits_2_trit1[] = {0, 0, 0, 1, 1, 1, 2, 2};
static uint8_t const bits_2_trit2[] = {0, 1, 2, 0, 1, 2, 0, 1};

/**
 * See header.
 *
 * The octets are consumed MSB first, three at a time (24 bits), and every
 * 3-bit group yields two trits.  A final group of fewer than 16 trits still
 * reads three whole octets from the input; if num_trits is a multiple of 16
 * no extra octets are read.
 */
void ntru_bits_2_trits(uint8_t const *octets, uint16_t num_trits, uint8_t *trits)
{
	uint32_t chunk, group;
	int pos;

	/* complete chunks: 8 groups of 3 bits -> 16 trits */
	for (; num_trits >= 16; num_trits -= 16)
	{
		chunk = ((uint32_t)octets[0] << 16) |
				((uint32_t)octets[1] << 8) |
				(uint32_t)octets[2];
		octets += 3;

		for (pos = 21; pos >= 0; pos -= 3)
		{
			group = (chunk >> pos) & 0x7;
			*trits++ = bits_2_trit1[group];
			*trits++ = bits_2_trit2[group];
		}
	}

	if (!num_trits)
	{
		return;
	}

	/* trailing partial chunk: still fetched as three octets */
	chunk = ((uint32_t)octets[0] << 16) |
			((uint32_t)octets[1] << 8) |
			(uint32_t)octets[2];

	/* at most 15 trits remain, so pos never drops below 0 while in use */
	for (pos = 21; num_trits > 0; pos -= 3)
	{
		group = (chunk >> pos) & 0x7;
		*trits++ = bits_2_trit1[group];
		if (--num_trits > 0)
		{
			*trits++ = bits_2_trit2[group];
			num_trits--;
		}
	}
}
|
|
|
|
/**
 * See header.
 *
 * Each pair of trits (t1, t2) is packed into the 3-bit value 3 * t1 + t2,
 * eight pairs (16 trits) per three output octets, MSB first.  A final group
 * of fewer than 16 trits is zero-padded but still occupies three whole
 * octets, and an odd last trit is packed as if its partner were 0.
 *
 * No octets are written for an empty trailing group.  When num_trits is a
 * multiple of 16 (including 0), the output is exactly 3 * num_trits / 16
 * octets.
 *
 * A pair whose value exceeds 7 (e.g. (2, 2)) cannot be represented in 3 bits.
 * It is clamped to 7, and the function then returns false.
 *
 * @return	true if every packed pair fit in 3 bits, false otherwise
 */
bool ntru_trits_2_bits(uint8_t const *trits, uint32_t num_trits, uint8_t *octets)
{
	uint32_t invalid = 0;
	uint32_t bits24, bits3;
	int shift;

	while (num_trits > 0)
	{
		/* pack up to 8 trit pairs into one 24-bit group, MSB first */
		bits24 = 0;
		for (shift = 21; shift >= 0 && num_trits > 0; shift -= 3)
		{
			bits3 = *trits++ * 3;
			if (--num_trits > 0)
			{
				bits3 += *trits++;
				--num_trits;
			}
			if (bits3 > 7)
			{
				/* not representable in 3 bits: clamp and remember */
				bits3 = 7;
				invalid = 1;
			}
			bits24 |= bits3 << shift;
		}

		/* output three octets; a partial group is zero-padded */
		*octets++ = (uint8_t)((bits24 >> 16) & 0xff);
		*octets++ = (uint8_t)((bits24 >> 8) & 0xff);
		*octets++ = (uint8_t)(bits24 & 0xff);
	}

	return invalid == 0;
}
|
|
|
|
/**
 * See header
 *
 * Packs the low two bits of each coefficient (coeffs[i] mod 4), four per
 * octet, MSB first.  Exactly ceil(num_coeffs / 4) octets are written.  The
 * last octet is zero-padded in its low bits when num_coeffs is not a
 * multiple of 4.  Nothing is written when num_coeffs is 0.
 */
void ntru_coeffs_mod4_2_octets(uint16_t num_coeffs, uint16_t const *coeffs, uint8_t *octets)
{
	uint8_t acc = 0;
	int shift = 6, i;

	for (i = 0; i < num_coeffs; i++)
	{
		acc |= (uint8_t)((coeffs[i] & 0x3) << shift);
		shift -= 2;
		if (shift < 0)
		{
			/* octet complete: emit it and start an empty one */
			*octets++ = acc;
			acc = 0;
			shift = 6;
		}
	}

	/* flush a partially filled octet; never touch the byte after the data */
	if (shift != 6)
	{
		*octets = acc;
	}
}
|
|
|
|
/**
 * See header.
 *
 * Interprets trits[0..4] as a base-3 number with trits[0] the least
 * significant digit and stores it in *octet.  The result is
 * trits[0] + 3*trits[1] + 9*trits[2] + 27*trits[3] + 81*trits[4],
 * reduced modulo 256 by the store.
 */
void ntru_trits_2_octet(uint8_t const *trits, uint8_t *octet)
{
	/* Horner's rule, innermost factor is the most significant trit */
	*octet = (uint8_t)(trits[0] + 3 * (trits[1] + 3 * (trits[2] +
					   3 * (trits[3] + 3 * trits[4]))));
}
|
|
|
|
/**
 * See header.
 *
 * Writes the five least significant base-3 digits of octet to trits[0..4],
 * least significant digit first.  Octet values of 243 and above do not fit
 * in five trits, and their highest digit is dropped.
 */
void ntru_octet_2_trits(uint8_t octet, uint8_t *trits)
{
	uint8_t *const end = trits + 5;

	for (; trits < end; trits++)
	{
		*trits = octet % 3;
		octet /= 3;
	}
}
|
|
|
|
/**
 * See header.
 *
 * For each of the in_len indices in in[], sets out[index] to 1 when plus1 is
 * true and to 2 (the encoding this file uses for -1) otherwise.  Entries of
 * out[] that are not indexed are left untouched.
 *
 * NOTE(review): indices are not range-checked; callers presumably guarantee
 * that every in[i] lies within out[].
 */
void ntru_indices_2_trits(uint16_t in_len, uint16_t const *in, bool plus1,
						  uint8_t *out)
{
	uint8_t const value = plus1 ? 1 : 2;
	uint16_t const *const end = in + in_len;

	while (in < end)
	{
		out[*in++] = value;
	}
}
|
|
|
|
/**
 * See header.
 *
 * Unpacks num_trits trits stored five per octet (via ntru_octet_2_trits).
 * The position of every trit equal to 1 is appended to indices_plus1, and
 * the position of every trit equal to 2 is appended to indices_minus1.  Zero
 * trits are skipped.  A final partial octet contributes only the trits still
 * needed.
 */
void ntru_packed_trits_2_indices(uint8_t const *in, uint16_t num_trits,
								 uint16_t *indices_plus1,
								 uint16_t *indices_minus1)
{
	uint8_t trits[5];
	uint16_t index = 0;
	int k;

	while (num_trits > 0)
	{
		ntru_octet_2_trits(*in++, trits);

		/* stop early inside the last octet once all trits are consumed */
		for (k = 0; k < 5 && num_trits > 0; k++, index++, num_trits--)
		{
			switch (trits[k])
			{
				case 1:
					*indices_plus1++ = index;
					break;
				case 2:
					*indices_minus1++ = index;
					break;
				default:
					break;
			}
		}
	}
}
|
|
|
|
/**
|
|
* See header.
|
|
*/
|
|
void ntru_indices_2_packed_trits(uint16_t const *indices, uint16_t num_plus1,
|
|
uint16_t num_minus1, uint16_t num_trits,
|
|
uint8_t *buf, uint8_t *out)
|
|
{
|
|
/* convert indices to an array of trits */
|
|
memset(buf, 0, num_trits);
|
|
ntru_indices_2_trits(num_plus1, indices, TRUE, buf);
|
|
ntru_indices_2_trits(num_minus1, indices + num_plus1, FALSE, buf);
|
|
|
|
/* pack the array of trits */
|
|
while (num_trits >= 5)
|
|
{
|
|
ntru_trits_2_octet(buf, out);
|
|
num_trits -= 5;
|
|
buf += 5;
|
|
++out;
|
|
}
|
|
if (num_trits)
|
|
{
|
|
uint8_t trits[5];
|
|
|
|
memcpy(trits, buf, num_trits);
|
|
memset(trits + num_trits, 0, sizeof(trits) - num_trits);
|
|
ntru_trits_2_octet(trits, out);
|
|
}
|
|
}
|
|
|
|
/**
 * See header
 *
 * Packs in_len elements of n_bits bits each into a bit stream, most
 * significant bit of each element first.  The stream is emitted as
 * ceil(in_len * n_bits / 8) octets, and the final octet is zero-padded in
 * its low bits.
 *
 * NOTE(review): the initial shift n_bits - 8 is used as a right-shift count,
 * so this appears to require n_bits >= 8.  Bits of an element above n_bits
 * are not masked off and would corrupt neighbouring output bits.  Confirm
 * both against the header.
 */
void ntru_elements_2_octets(uint16_t in_len, uint16_t const *in, uint8_t n_bits,
							uint8_t *out)
{
	int const start_shift = n_bits - 8;
	uint16_t pending = 0;		/* bits already placed in the next octet */
	int shift = start_shift;	/* right shift aligning in[idx] with it */
	int idx = 0;

	while (idx < in_len)
	{
		/* complete the current octet from the top of in[idx] and emit it */
		pending |= in[idx] >> shift;
		*out++ = (uint8_t)pending;

		shift = 8 - shift;
		if (shift >= 1)
		{
			/* fewer than 8 bits of in[idx] remain: left-align them in a
			 * partial octet and advance to the next element */
			pending = in[idx] << shift;
			idx++;
		}
		else
		{
			/* at least 8 bits of in[idx] remain: the next octet is taken
			 * entirely from the same element */
			shift += n_bits;
			pending = 0;
		}
		shift = n_bits - shift;
	}

	/* emit the zero-padded partial octet if any bits are left over */
	if (shift != start_shift)
	{
		*out = (uint8_t)pending;
	}
}
|
|
|
|
|
|
/**
 * See header.
 *
 * Unpacks a bit stream of in_len octets (MSB first) into elements of n_bits
 * bits each.  Every output element is masked to n_bits bits.  Trailing bits
 * that do not form a complete element are discarded, so
 * floor(in_len * 8 / n_bits) elements are written.
 *
 * NOTE(review): at most one element is emitted per input octet and the
 * accumulator is 16 bits wide, so this appears to assume
 * 8 <= n_bits <= 16.  Confirm against the header.
 */
void ntru_octets_2_elements(uint16_t in_len, uint8_t const *in, uint8_t n_bits,
							uint16_t *out)
{
	uint16_t const mask = (1 << n_bits) - 1;
	uint16_t acc = 0;	/* partially assembled element */
	int shift = n_bits;
	int pos;

	for (pos = 0; pos < in_len; pos++)
	{
		uint16_t const octet = in[pos];

		shift = 8 - shift;
		if (shift >= 0)
		{
			/* the top 8 - shift bits of this octet complete an element */
			acc |= octet >> shift;
			*out++ = acc & mask;
			acc = 0;
		}
		else
		{
			/* this octet is absorbed entirely by the current element */
			shift += n_bits;
		}

		/* carry the octet's remaining bits into the next element */
		shift = n_bits - shift;
		acc |= octet << shift;
	}
}
|