diff --git a/aux/ristretto/ristretto.sage b/aux/ristretto/ristretto.sage index 0f217a6..68682da 100644 --- a/aux/ristretto/ristretto.sage +++ b/aux/ristretto/ristretto.sage @@ -529,11 +529,26 @@ def test(cls,n): Q = cls() for i in xrange(n): #print binascii.hexlify(Q.encode()) - QQ = cls.decode(Q.encode()) + QE = Q.encode() + QQ = cls.decode(QE) if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q))) + + # Testing s -> 1/s: encodes -point on cofactor + s = cls.bytesToGf(QE) + if s != 0: + ss = cls.gfToBytes(1/s,mustBePositive=True) + try: + QN = cls.decode(ss) + if cls.cofactor == 8: + raise TestFailedException("1/s shouldnt work for cofactor 8") + if QN != -Q: + raise TestFailedException("s -> 1/s should negate point for cofactor 4") + except InvalidEncodingException as e: + # Should be raised iff cofactor==8 + if cls.cofactor == 4: + raise TestFailedException("s -> 1/s should work for cofactor 4") QT = Q - QE = Q.encode() for h in xrange(cls.cofactor): QT = QT.torque() if QT.encode() != QE: