From dd193a3ec596fe27810f2cd5fb860de6452ec982 Mon Sep 17 00:00:00 2001 From: Michael Hamburg Date: Fri, 23 Jun 2017 14:28:54 -0700 Subject: [PATCH] ristretto work --- aux/ristretto.sage | 193 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 193 insertions(+) create mode 100644 aux/ristretto.sage diff --git a/aux/ristretto.sage b/aux/ristretto.sage new file mode 100644 index 0000000..907a7bc --- /dev/null +++ b/aux/ristretto.sage @@ -0,0 +1,193 @@ + +class InvalidEncodingException(Exception): pass +class NotOnCurveException(Exception): pass + +def lobit(x): return int(x) & 1 +def hibit(x): return lobit(2*x) +def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)]) +def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x)) + +class EdwardsPoint(object): + """Abstract class for point an an Edwards curve; needs F,a,d to work""" + def __init__(self,x=0,y=1): + x = self.x = self.F(x) + y = self.y = self.F(y) + if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2: + raise NotOnCurveException() + + def __repr__(self): + return "%s(%d,%d)" % (self.__class__.__name__, self.x, self.y) + + def __iter__(self): + yield self.x + yield self.y + + def __add__(self,other): + x,y = self + X,Y = other + a,d = self.a,self.d + return self.__class__( + (x*Y+y*X)/(1+d*x*y*X*Y), + (y*Y-a*x*X)/(1-d*x*y*X*Y) + ) + + def __neg__(self): return self.__class__(-self.x,self.y) + def __sub__(self,other): return self + (-other) + def __rmul__(self,other): return self*other + def __eq__(self,other): return tuple(self) == tuple(other) + def __ne__(self,other): return not (self==other) + + def __mul__(self,exp): + exp = int(exp) + total = self.__class__() + work = self + while exp != 0: + if exp & 1: total += work + work += work + exp >>= 1 + return total + +class Ed25519Point(EdwardsPoint): + F = GF(2^255-19) + d = F(-121665/121666) + a = F(-1) + i = sqrt(Ed25519Point.F(-1)) + + @classmethod + def base(cls): + y = cls.F(4/5) + x = sqrt((y^2-1)/(cls.d*y^2+1)) + if lobit(x): x = -x + return cls(x,y) + + def torque(self): + return self.__class__(self.y*self.i, self.x*self.i) + +class RistrettoOption1Point(Ed25519Point): + """Like current decaf but tweaked for simplicity""" + dMont = Ed25519Point.F(-121665) + encLen = 32 + + def __eq__(self,other): + x,y = self + X,Y = other + return x*Y == X*y or x*X == y*Y + + def encode(self): + x,y = self + a,d = self.a,self.d + + if x*y == 0: + # This happens anyway with straightforward impl + return enc_le(0,self.encLen) + + if not is_square((1-y)/(1+y)): + raise Exception("Unimplemented: odd point in RistrettoPoint.encode") + + # Choose representative in 4-torsion group + if lobit(x*y): (x,y) = (self.i*y,self.i*x) + if lobit(x): x,y = -x,-y + + s = sqrt((1-y)/(1+y)) + if lobit(s): s = -s + return enc_le(s,self.encLen) + + @classmethod + def decode(cls,s): + if len(s) != cls.encLen: + raise InvalidEncodingException("wrong length %d" % len(s)) + s = dec_le(s) + if s == 0: return cls(0,1) + if s < 0 or s >= cls.F.modulus() or lobit(s): + raise InvalidEncodingException("%d out of range!" % s) + s = cls.F(s) + + magic = 4*cls.dMont-4 + if not is_square(magic*s^2 / ((s^2-1)^2 - s^2 * magic)): + raise InvalidEncodingException("Not on curve") + + x = sqrt(magic*s^2 / ((s^2-1)^2 - magic * s^2)) + if lobit(x): x=-x + y = (1-s^2)/(1+s^2) + + if lobit(x*y): + raise InvalidEncodingException("x*y has high bit") + + return cls(x,y) + +class RistrettoOption2Point(Ed25519Point): + """Works like current decaf""" + dMont = Ed25519Point.F(-121665) + magic = sqrt(dMont-1) + encLen = 32 + + def __eq__(self,other): + x,y = self + X,Y = other + return x*Y == X*y or x*X == y*Y + + def encode(self): + x,y = self + a,d = self.a,self.d + + if x*y == 0: + # This will happen anyway with straightforward square root trick + return enc_le(0,self.encLen) + + if not is_square((1-y)/(1+y)): + raise Exception("Unimplemented: odd point in RistrettoPoint.encode") + + # Choose representative in 4-torsion group + if hibit(self.magic/(x*y)): (x,y) = (self.i*y,self.i*x) + if hibit(2*self.magic/x): x,y = -x,-y + + s = sqrt((1-y)/(1+y)) + if hibit(s): s = -s + return enc_le(s,self.encLen) + + @classmethod + def decode(cls,s): + if len(s) != cls.encLen: + raise InvalidEncodingException("wrong length %d" % len(s)) + s = dec_le(s) + if s == 0: return cls(0,1) + if s < 0 or s >= (cls.F.modulus()+1)/2: + raise InvalidEncodingException("%d out of range!" % s) + s = cls.F(s) + + if not is_square(s^4 + (2-4*cls.dMont)*s^2 + 1): + raise InvalidEncodingException("Not on curve") + + t = sqrt(s^4 + (2-4*cls.dMont)*s^2 + 1)/s + if hibit(t): t = -t + + y = (1-s^2)/(1+s^2) + x = 2*cls.magic/t + + if y == 0 or lobit(t/y): + raise InvalidEncodingException("t/y has high bit") + + return cls(x,y) + +class TestFailedException(Exception): pass +def test(cls,n): + # TODO: test corner cases like 0,1,i + P = cls.base() + Q = cls() + for i in xrange(n): + QQ = cls.decode(Q.encode()) + if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q))) + if Q.encode() != Q.torque().encode(): + raise TestFailedException("Can't torque %s" % str(Q)) + + Q0 = Q + P + if Q0 == Q: raise TestFailedException("Addition doesn't work") + if Q0-P != Q: raise TestFailedException("Subtraction doesn't work") + + r = randint(1,1000) + Q1 = Q0*r + Q2 = Q0*(r+1) + if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work") + Q = Q1 + + \ No newline at end of file