diff --git a/aux/ristretto.sage b/aux/ristretto.sage index 388f5d3..ac0c450 100644 --- a/aux/ristretto.sage +++ b/aux/ristretto.sage @@ -10,6 +10,7 @@ def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x)) def randombytes(n): return bytearray([randint(0,255) for _ in range(n)]) def optimized_version_of(spec): + """Decorator: This function is an optimized version of some specification""" def decorator(f): def wrapper(self,*args,**kwargs): try: spec_ans = getattr(self,spec,spec)(*args,**kwargs),None @@ -25,7 +26,7 @@ def optimized_version_of(spec): if spec_ans[0] != opt_ans[0]: raise SpecException("Mismatch in %s: %s != %s" % (f.__name__,str(spec_ans[0]),str(opt_ans[0]))) - if opt_ans[1] is not None: raise opt_ans[1] + if opt_ans[1] is not None: raise else: return opt_ans[0] wrapper.__name__ = f.__name__ return wrapper @@ -52,8 +53,8 @@ def isqrt_i(x): if is_square(x): return True,1/sqrt(x) else: return False,1/sqrt(x*gen) -class EdwardsPoint(object): - """Abstract class for point an an Edwards curve; needs F,a,d to work""" +class QuotientEdwardsPoint(object): + """Abstract class for point an a quotiented Edwards curve; needs F,a,d,cofactor to work""" def __init__(self,x=0,y=1): x = self.x = self.F(x) y = self.y = self.F(y) @@ -79,11 +80,16 @@ class EdwardsPoint(object): def __neg__(self): return self.__class__(-self.x,self.y) def __sub__(self,other): return self + (-other) def __rmul__(self,other): return self*other - def __eq__(self,other): return tuple(self) == tuple(other) + def __eq__(self,other): + """NB: this is the only method that is different from the usual one""" + x,y = self + X,Y = other + return x*Y == X*y or (self.cofactor==8 and -self.a*x*X == y*Y) def __ne__(self,other): return not (self==other) def __mul__(self,exp): exp = int(exp) + if exp < 0: exp,self = -exp,-self total = self.__class__() work = self while exp != 0: @@ -103,14 +109,9 @@ class EdwardsPoint(object): return self.__class__(self.y*self.i, self.x*self.i) else: return self.__class__(-self.x, -self.y) - -class RistrettoPoint(EdwardsPoint): - """Like current decaf but tweaked for simplicity""" - def __eq__(self,other): - x,y = self - X,Y = other - return x*Y == X*y or x*X == y*Y + + # Utility functions @classmethod def bytesToGf(cls,bytes,mustBeProper=True,mustBePositive=False): """Convert little-endian bytes to field element, sanity check length""" @@ -123,6 +124,14 @@ class RistrettoPoint(EdwardsPoint): raise InvalidEncodingException("%d is negative!" % s) return cls.F(s) + @classmethod + def gfToBytes(cls,x,mustBePositive=False): + """Convert little-endian bytes to field element, sanity check length""" + if lobit(x) and mustBePositive: x = -x + return enc_le(x,cls.encLen) + +class RistrettoPoint(QuotientEdwardsPoint): + """The new Ristretto group""" def encodeSpec(self): """Unoptimized specification for encoding""" x,y = self @@ -133,8 +142,7 @@ class RistrettoPoint(EdwardsPoint): if lobit(x): x,y = -x,-y s = xsqrt(self.a*(y-1)/(y+1),exn=Exception("Unimplemented: point is odd: " + str(self))) - - return enc_le(s,self.encLen) + return self.gfToBytes(s) @classmethod def decodeSpec(cls,s): @@ -162,9 +170,8 @@ class RistrettoPoint(EdwardsPoint): i1 = isr*u1 i2 = isr*u2 z_inv = i1*i2*t - - rotate = self.cofactor==8 and lobit(t*z_inv) - if rotate: + + if self.cofactor==8 and lobit(t*z_inv): x,y = y*self.i,x*self.i den_inv = self.magic * i1 else: @@ -172,9 +179,7 @@ class RistrettoPoint(EdwardsPoint): if lobit(x*z_inv): y = -y s = (z-y) * den_inv - if lobit(s): s=-s - - return enc_le(s,self.encLen) + return self.gfToBytes(s,mustBePositive=True) @classmethod @optimized_version_of("decodeSpec") @@ -224,8 +229,7 @@ class RistrettoPoint(EdwardsPoint): else: sgn,s,t = -1,xsqrt(n2), r*(r-1)*(a+d)^2 / den - 1 - ret = cls.fromJacobiQuartic(s,t,sgn) - return ret + return cls.fromJacobiQuartic(s,t,sgn) @classmethod @optimized_version_of("elligatorSpec") @@ -243,7 +247,51 @@ class RistrettoPoint(EdwardsPoint): s = isri*num t = isri*s*(r-1)*(d+a)^2 + sgn return cls.fromJacobiQuartic(s,t,sgn) - + + +class Decaf1Point(QuotientEdwardsPoint): + """Like current decaf but tweaked for simplicity""" + def encodeSpec(self): + """Unoptimized specification for encoding""" + a,d = self.a,self.d + x,y = self + if x==0: return(self.gfToBytes(0)) + + isr2 = isqrt(a*(y^2-1)) / self.magic + altx = 1/isr2*self.isoMagic + if lobit(altx): s = (1+x*y*isr2)/(a*x) + else: s = (1-x*y*isr2)/(a*x) + + # TODO: cofactor 8 + return self.gfToBytes(s,mustBePositive=True) + + @classmethod + def decodeSpec(cls,s): + """Unoptimized specification for decoding""" + a,d = cls.a,cls.d + s = cls.bytesToGf(s,mustBePositive=True) + + if s==0: return cls() + isr = isqrt(s^4 + 2*(a-2*d)*s^2 + 1) + altx = 2*s*isr*cls.isoMagic + if lobit(altx): isr = -isr + x = 2*s / (1+a*s^2) + y = (1-a*s^2) * isr + + # TODO: cofactor 8 + return cls(x,y) + + @optimized_version_of("encodeSpec") + def encode(self): + """Encode, optimized version""" + return self.encodeSpec() # TODO + + @classmethod + @optimized_version_of("decodeSpec") + def decode(cls,s): + """Decode, optimized version""" + return cls.decodeSpec(s) # TODO + class Ed25519Point(RistrettoPoint): F = GF(2^255-19) d = F(-121665/121666) @@ -261,51 +309,50 @@ class Ed25519Point(RistrettoPoint): if lobit(x): x = -x return cls(x,y) -class TwistedEd448GoldilocksPoint(RistrettoPoint): +class IsoEd448Point(RistrettoPoint): F = GF(2^448-2^224-1) - d = F(-39082) - a = F(-1) + d = F(39082/39081) + a = F(1) qnr = -1 magic = isqrt(a*d-1) cofactor = 4 encLen = 56 - + @classmethod def base(cls): - y = cls.F(6) # TODO: no it isn't - x = sqrt((y^2-1)/(cls.d*y^2+1)) - if lobit(x): x = -x - return cls(x,y) - -class Ed448GoldilocksPoint(RistrettoPoint): - # TODO: decaf vs ristretto + # = ..., -3/2 + return cls.decodeSpec(bytearray(binascii.unhexlify( + "00000000000000000000000000000000000000000000000000000000"+ + "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff"))) + +class TwistedEd448GoldilocksPoint(Decaf1Point): F = GF(2^448-2^224-1) - d = F(-39081) - a = F(1) + d = F(-39082) + a = F(-1) qnr = -1 magic = isqrt(a*d-1) cofactor = 4 encLen = 56 - + isoMagic = IsoEd448Point.magic + @classmethod def base(cls): - return cls( - 0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa955555555555555555555555555555555555555555555555555555555, - 0xae05e9634ad7048db359d6205086c2b0036ed7a035884dd7b7e36d728ad8c4b80d6565833a2a3098bbbcb2bed1cda06bdaeafbcdea9386ed - ) + return cls.decodeSpec(bytearray(binascii.unhexlify( + "00000000000000000000000000000000000000000000000000000000"+ + "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff"))) -class IsoEd448Point(RistrettoPoint): +class Ed448GoldilocksPoint(Decaf1Point): F = GF(2^448-2^224-1) - d = F(1/39081+1) + d = F(-39081) a = F(1) qnr = -1 magic = isqrt(a*d-1) cofactor = 4 encLen = 56 + isoMagic = IsoEd448Point.magic @classmethod def base(cls): - # = ..., -3/2 return cls.decodeSpec(bytearray(binascii.unhexlify( "00000000000000000000000000000000000000000000000000000000"+ "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff"))) @@ -338,15 +385,27 @@ def test(cls,n): if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work") Q = Q1 test(Ed25519Point,100) +test(IsoEd448Point,100) test(TwistedEd448GoldilocksPoint,100) test(Ed448GoldilocksPoint,100) -test(IsoEd448Point,100) + +def gangtest(classes,n): + for i in xrange(n): + rets = [bytes((cls.base()*i).encode()) for cls in classes] + if len(set(rets)) != 1: + print "Divergence at %d" % i + for c,ret in zip(classes,rets): + print c,binascii.hexlify(ret) + print +gangtest([IsoEd448Point,TwistedEd448GoldilocksPoint,Ed448GoldilocksPoint],100) + + def testElligator(cls,n): for i in xrange(n): cls.elligator(randombytes(cls.encLen)) testElligator(Ed25519Point,100) -testElligator(Ed448GoldilocksPoint,100) -testElligator(TwistedEd448GoldilocksPoint,100) testElligator(IsoEd448Point,100) +# testElligator(Ed448GoldilocksPoint,100) +# testElligator(TwistedEd448GoldilocksPoint,100)