diff --git a/aux/ristretto.sage b/aux/ristretto.sage index a85a2cc..d2c32d6 100644 --- a/aux/ristretto.sage +++ b/aux/ristretto.sage @@ -167,22 +167,36 @@ class RistrettoPoint(QuotientEdwardsPoint): a,d = self.a,self.d x,y,z,t = self.xyzt() - u1 = a*(y+z)*(y-z) - u2 = x*y # = t*z - isr = isqrt(u1*u2^2) - i1 = isr*u1 - i2 = isr*u2 - z_inv = i1*i2*t + if self.cofactor==8: + u1 = a*(y+z)*(y-z) + u2 = x*y # = t*z + isr = isqrt(u1*u2^2) + i1 = isr*u1 + i2 = isr*u2 + z_inv = i1*i2*t - if self.cofactor==8 and negative(t*z_inv): - if a==-1: x,y = y*self.i,x*self.i - else: x,y = -y,x # TODO: test - den_inv = self.magic * i1 + if self.cofactor==8 and negative(t*z_inv): + if a==-1: x,y = y*self.i,x*self.i + else: x,y = -y,x # TODO: test + den_inv = self.magic * i1 + else: + den_inv = i2 + + if negative(x*z_inv): y = -y + s = (z-y) * den_inv else: + u1 = a*(y+z)*(y-z) + u2 = x*y # = t*z + isr = isqrt(u1*u2^2) + i1 = isr*u1 + i2 = isr*u2 + z_inv = i1*i2*t den_inv = i2 - if negative(x*z_inv): y = -y - s = (z-y) * den_inv + if negative(x*z_inv): y = -y + s = (z-y) * den_inv + + return self.gfToBytes(s,mustBePositive=True) @classmethod