Skip to content

Commit

Permalink
(branch-faster-f25519): get multiply working
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Harvey committed Jun 3, 2014
1 parent 206c047 commit 08c591b
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 32 deletions.
58 changes: 46 additions & 12 deletions python-models/mult.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ def mul256_rs( wordX, wordY ):
if r < NDIGITS:
sX = r
sY = 0
count = r+1
else:
sX = NDIGITS-1
sY = (r-NDIGITS+1)
for i in range(sX-sY+1):
count = 2*NDIGITS-1-r
for i in range(count):
r64 += wordX[sX] * wordY[sY]
sX -= 1
sY += 1
Expand All @@ -24,23 +26,55 @@ def mul256_rs( wordX, wordY ):
res[NDIGITS*2-1] = r64
return res

def toWords(v):
    """Split the integer v into a little-endian list of NBITS-wide words.

    One word is produced for every bit offset 0, NBITS, 2*NBITS, ... that is
    below 255; each word is the corresponding slice of v masked with MASK.
    """
    words = []
    for offset in range(0, 255, NBITS):
        words.append((v >> offset) & MASK)
    return words

def toNum(w):
    """Recombine a little-endian list of words into a single integer.

    Word i is weighted by 2**(i*NBITS). Words wider than NBITS bits are
    simply added in, so un-normalised word lists are accepted as well.
    """
    total = 0
    for index, word in enumerate(w):
        total += word << (index * NBITS)
    return total

def mul(x,y):
    """Return the full (unreduced) product x*y via the word-level multiplier.

    Both operands are split into words with toWords, multiplied with
    mul256_rs, and the resulting word list is recombined with toNum.
    """
    # The word splitting/joining used to be written out inline here, followed
    # by a second (unreachable) return; the shared helpers are the single
    # source of truth for the word layout.
    return toNum( mul256_rs(toWords(x), toWords(y)) )

def mul_reduce_approx(wS):
print "|wS = ", hex(toNum(wS))
for i in range(2):
wD = [ (wS[j+8] >> 23) | ((wS[j+9] << 6) & MASK)
for j in range(9) ]
wS[8] &= 0x7fffff
for j in range(9,18):
wS[j] = 0

r64 = 0
for j in range(9):
r64 = r64 + wD[j]*19 + wS[j]
wS[j] = r64 & MASK
r64 >>= NBITS
wS[9] = r64

return wS

# The field prime 2^255 - 19.
P25519 = (1 << 255) - 19

def mul_mod(x,y):
    """Return the product of x and y after reduction by P25519.

    The double-width product from mul256_rs is folded by mul_reduce_approx;
    if the folded value is still at least P25519, P25519 is subtracted once.
    """
    product_words = mul256_rs(toWords(x), toWords(y))
    approx = toNum( mul_reduce_approx(product_words) )
    # At most one conditional subtraction is applied to the folded value.
    if approx < P25519:
        return approx
    return approx - P25519

# Operand values for the self-test: small values, word-boundary values and
# values just below P25519. Every entry is less than P25519.
# The list was previously closed after (1<<256)-1 and then continued, which
# is a syntax error; this is the single, well-formed list.
tstlist = [ 0, 1, 0x80000000, 0xFFFFFFFF, (1 << 64)-1, (1 << 64),
    (1<<128)-1,
    P25519-0x80000000,
    P25519-1 ]

if __name__ == '__main__':
for x in tstlist:
for y in tstlist:
res = mul(x,y)
print "x=", hex(x)
print "y=", hex(y)
print "x*y=", hex(x*y)
res = mul_mod(x,y)
ans = (x*y) % P25519
print "x =", hex(x)
print "y =", hex(y)
print "x*y=", hex(ans)
print "res=", hex(res)
assert(res == x*y)
assert(res == ans)

118 changes: 98 additions & 20 deletions src/f25519mul_mini.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,54 @@
#define MPIMINI_INTERNAL_API
#include "f25519_mini.h"

#if 0
static void mul_reduce_approx(uint32_t *dst, uint32_t *src)
// dst is 8 words long
// src is 16 words long
#define USE_64BIT

#ifdef USE_64BIT

typedef uint64_t U64;

// Adds a row of 'count' digit products to the accumulator *sum:
//   *sum += s_up[0]*s_dn[0] + s_up[1]*s_dn[-1] + ... + s_up[count-1]*s_dn[-(count-1)]
// s_up is walked upwards and s_dn downwards from the addresses given.
// Each product is formed in 64 bits. A count <= 0 leaves *sum unchanged.
static void u64_sum_row( U64 *sum, const int32_t *s_up, const int32_t *s_dn, int count)
{
    U64 acc = *sum;
    int i;

    // Index from the fixed base pointers instead of stepping them. Stepping
    // s_dn down after the last term would form a pointer one element before
    // the array when s_dn starts 'count-1' elements in, which is undefined
    // behaviour even if that pointer is never dereferenced.
    for ( i = 0; i < count; i++ )
    {
        acc += (U64)s_up[i] * s_dn[-i];
    }
    *sum = acc;
}

// Shifts *val right by F25519MINI_BITS bits in place, discarding the
// low-order bits.
static void u64_shift_bits( U64 *val )
{
    *val = *val >> F25519MINI_BITS;
}

// Returns the bits of *src selected by F25519MINI_BITMASK, converted to
// int32_t. *src itself is not modified.
static int32_t u64_mask_bits ( const U64 *src )
{
    const U64 masked = *src & F25519MINI_BITMASK;

    return (int32_t) masked;
}

// Resets the accumulator *val to zero.
static void u64_clear( U64 *val )
{
    val[0] = (U64)0;
}

#endif

// Double-length digit buffer: holds F25519MINI_DIGITS*2 digits.
// F25519_mul3_mini fills one of these with the unreduced product of two
// operands before handing it to mul_reduce_approx.
typedef struct
{
int32_t digits[F25519MINI_DIGITS*2];
}
MulResult;

// mul_reduce_approx hard-codes the position of bit 255 as digit 8, bit 23
// (255 = 8*29 + 23) and loops over exactly 9 / 18 digits, so it is only
// valid for 29-bit digits and 9-digit values. Fail the build otherwise.
#if (F25519MINI_BITS != 29) || (F25519MINI_DIGITS != 9)
#error "mul_reduce_approx is incorrect for F25519MINI_BITS"
#endif

static void mul_reduce_approx(int32_t *dst, int32_t *src)
// dst is 9 words long
// src is 18 words long
// Forms dst = (src % (2^255-19)) + {0..1} * (2^255-19)
{
int i;
Expand All @@ -24,39 +68,73 @@ static void mul_reduce_approx(uint32_t *dst, uint32_t *src)
for (i=0; i<2; i++)
{
int j;
U64 r64;


// Do dst <- (src >> 255), which will be our N~
for (j=0; j < 8; j++)
dst[j] = (src[j+7] >> 31) | (src[j+8] << 1);

// 255 bits is 8 digits and 23 bits

for (j=0; j < 9; j++)
dst[j] = (src[j+8] >> 23) | ((src[j+9] << 6) & F25519MINI_BITMASK);

// Do src' -= (N~ * 2^255), i.e. clear all bits > 255

src[7] &= 0x7FFFFFFF;
for (j=8; j < 16; j++)
src[8] &= F25519MINI_BITMASK >> 6;
for (j=9; j < 18; j++)
src[j] = 0;

// Adjust: src'' += (N~ * 19), so src'' == src - (N~ * (2^255-19))
mpi_mulrow_mini_ (src, 19, dst);
// We can ignore the carry - mulrow will correctly set src[0..8]
// and we know the true value of src won't be > 20 * (2^255)

u64_clear(&r64);
for (j=0; j < 9; j++)
{
r64 += src[j];
r64 += (U64)dst[j] * 19;
src[j] = u64_mask_bits(&r64);
u64_shift_bits(&r64);
}
src[9] = u64_mask_bits(&r64);

// Now src = (src - N~ * (2^255-19))
// First time round, max value is ~ 19 * (2^255-19)
}

for (i=0; i<8; i++)
for (i=0; i<9; i++)
dst[i] = src[i];
}

#endif

// Sets *res to the product of *s1 and *s2.
//
// The double-length product is built one output digit (column) at a time:
// column c accumulates every s1->digits[i] * s2->digits[c-i] that exists,
// the low F25519MINI_BITS bits become digit c and the remainder carries
// into the next column. The product is then folded by mul_reduce_approx
// and passed to F25519_reduce_mini_ (declared elsewhere; presumably the
// final canonical reduction - confirm).
// The product is built in a local buffer before res->digits is written.
void F25519_mul3_mini(F25519_Mini *res, const F25519_Mini *s1, const F25519_Mini *s2)
{
    MulResult product;
    U64 carry;
    int col;

    u64_clear(&carry);

    for ( col = 0; col < F25519MINI_DIGITS*2 - 1; col++ )
    {
        // Range of s1 indices that contribute to this column; the matching
        // s2 index runs downwards from 'hi_idx' as the s1 index runs
        // upwards from 'lo_idx'.
        int lo_idx;
        int hi_idx;

        if ( col < F25519MINI_DIGITS )
        {
            lo_idx = 0;
            hi_idx = col;
        }
        else
        {
            lo_idx = col - F25519MINI_DIGITS + 1;
            hi_idx = F25519MINI_DIGITS - 1;
        }

        u64_sum_row(&carry,
                    &s1->digits[lo_idx],
                    &s2->digits[hi_idx],
                    hi_idx - lo_idx + 1);

        product.digits[col] = u64_mask_bits(&carry);
        u64_shift_bits(&carry);
    }

    // col == F25519MINI_DIGITS*2 - 1 here: the top digit is whatever carry
    // is left over (masked to digit width).
    product.digits[col] = u64_mask_bits(&carry);

    mul_reduce_approx(res->digits, product.digits);
    F25519_reduce_mini_(res);
}

void F25519_sqr_mini(F25519_Mini *res, const F25519_Mini *s)
Expand Down

0 comments on commit 08c591b

Please sign in to comment.