From 03ba02f90d606b1e52584e523af7e4adf8acd21b Mon Sep 17 00:00:00 2001 From: Michael Hamburg Date: Wed, 5 Jul 2017 20:37:31 -0700 Subject: [PATCH] more ristretto --- aux/ristretto.sage | 85 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 83 insertions(+), 2 deletions(-) diff --git a/aux/ristretto.sage b/aux/ristretto.sage index 89eec7c..501704c 100644 --- a/aux/ristretto.sage +++ b/aux/ristretto.sage @@ -7,6 +7,12 @@ def hibit(x): return lobit(2*x) def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)]) def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x)) +def isqrt(x,exn=InvalidEncodingException("Not on curve")): + """Return 1/sqrt(x)""" + if x==0: return 0 + if not is_square(x): raise exn + return 1/sqrt(x) + class EdwardsPoint(object): """Abstract class for point an an Edwards curve; needs F,a,d to work""" def __init__(self,x=0,y=1): @@ -46,12 +52,17 @@ class EdwardsPoint(object): work += work exp >>= 1 return total + + def xyzt(self): + x,y = self + z = self.F.random_element() + return x*z,y*z,z,x*y*z class Ed25519Point(EdwardsPoint): F = GF(2^255-19) d = F(-121665/121666) a = F(-1) - i = sqrt(Ed25519Point.F(-1)) + i = sqrt(F(-1)) @classmethod def base(cls): @@ -100,8 +111,78 @@ class RistrettoPoint(Ed25519Point): if lobit(x*y) or x==0: raise InvalidEncodingException("x*y has high bit") - + return cls(x,y) + +class OptimizedRistrettoPoint(RistrettoPoint): + magic = isqrt(RistrettoPoint.d+1) + + """Like Ristretto but uses isqrt instead""" + @classmethod + def isqrt_and_inv(cls,isqrt,inv,*args,**kwargs): + s = isqrt(isqrt*inv^2) + return s*inv, s^2*isqrt*inv + + def encode(self): + right_answer = super(OptimizedRistrettoPoint,self).encode() + x,y,z,t = self.xyzt() + x *= self.i + + u1 = (z+y)*(z-y) + u2 = x*y # = t*z + isr = isqrt(u1 * u2^2) + i1 = isr*u1 + i2 = isr*u2 + z_inv = i1*i2*t + + rotate = lobit(t*self.i*z_inv) + if rotate: + x,y = y,x + den_inv = self.magic * i1 + else: + den_inv = i2 + + if rotate ^^ lobit(x*z_inv): y = -y + s = (z-y) * den_inv + if s==0: s = F(1) + if lobit(s): s=-s + + ret = enc_le(s,self.encLen) + assert ret == right_answer + return ret + + @classmethod + def decode(cls,s): + right_answer = super(cls,OptimizedRistrettoPoint).decode(s) + + # Sanity check s + if len(s) != cls.encLen: + raise InvalidEncodingException("wrong length %d" % len(s)) + s = dec_le(s) + if s < 0 or s >= cls.F.modulus() or lobit(s): + raise InvalidEncodingException("%d out of range!" % s) + s = cls.F(s) + + yden = 1+s^2 + ynum = 1-s^2 + yden_sqr = yden^2 + xden_sqr = -cls.d*ynum^2 - yden_sqr + + isr = isqrt(xden_sqr * yden_sqr) + + xden_inv = isr * yden + yden_inv = xden_inv * isr * xden_sqr + + x = 2*s*xden_inv + if lobit(x): x = -x + y = ynum * yden_inv + + if lobit(x*y) or x==0: + raise InvalidEncodingException("x*y has high bit") + + ret = cls(x,y) + assert ret == right_answer + return ret class DecafPoint(Ed25519Point): """Works like current decaf"""