/* Copyright (c) 2014 Cryptography Research, Inc.
 * Released under the MIT License. See LICENSE.txt for license information.
 */
#include "p521.h"
#include <assert.h> /* assert() is used below; it may also come in via p521.h */
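
/* Accumulator for one 3-limb block product: each lane of the three 128-bit
 * column sums is split into bits [0,58) (lo), bits [58,116) (hi) and the
 * overflow above bit 116 (hier).  The signed hexad routines leave hier unused
 * and fold everything above bit 58 into hi.
 */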
typedef struct {
    uint64x3_t lo, hi, hier;
} nonad_t;

static __inline__ uint64_t is_zero(uint64_t a) {
    /* let's hope the compiler isn't clever enough to optimize this. */
    return (((__uint128_t)a)-1)>>64;
}
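
/* Widening 64x64 -> 128-bit multiplies, unsigned and signed. */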
static inline __uint128_t widemulu(uint64_t a, uint64_t b) {
    return ((__uint128_t)(a)) * b;
}

static inline __int128_t widemuls(int64_t a, int64_t b) {
    return ((__int128_t)(a)) * b;
}

/* This is a trick to prevent terrible register allocation by hiding things from clang's optimizer */
static inline uint64_t opacify(uint64_t x) {
    __asm__ volatile("" : "+r"(x));
    return x;
}

/* These used to be hexads, leading to 10% better performance, but there were overflow issues */
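/* Multiply two 3-limb blocks.  The column sums are those of a product
 * modulo x^3 - 2: the wrap-around terms a[2]*b[1], a[2]*b[2] and a[1]*b[2]
 * are doubled.  Each 128-bit column sum is then cut into 58-bit pieces
 * (lo, hi) with the leftover bits in hier.
 */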
static inline void nonad_mul (
    nonad_t *hex,
    const uint64_t *a,
    const uint64_t *b
) {
    __uint128_t xu, xv, xw;

    uint64_t tmp = opacify(a[2]);
    xw = widemulu(tmp, b[0]);
    tmp <<= 1;
    xu = widemulu(tmp, b[1]);
    xv = widemulu(tmp, b[2]);

    tmp = opacify(a[1]);
    xw += widemulu(tmp, b[1]);
    xv += widemulu(tmp, b[0]);
    tmp <<= 1;
    xu += widemulu(tmp, b[2]);

    tmp = opacify(a[0]);
    xu += widemulu(tmp, b[0]);
    xv += widemulu(tmp, b[1]);
    xw += widemulu(tmp, b[2]);

    uint64x3_t
        lo = { (uint64_t)(xu), (uint64_t)(xv), (uint64_t)(xw), 0 },
        hi = { (uint64_t)(xu>>64), (uint64_t)(xv>>64), (uint64_t)(xw>>64), 0 };

    hex->hier = hi>>52;
    hex->hi = (hi<<12)>>6 | lo>>58;
    hex->lo = lo & mask58;
}
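
/* Signed variant, applied to the limb differences in p521_mul / p521_sqr for
 * the Karatsuba-style cross terms.  Same column structure as nonad_mul, but
 * everything above bit 58 stays in hi (a two-way split); hier is left unset,
 * as the commented-out three-way split below shows.
 */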
static inline void hexad_mul_signed (
    nonad_t *hex,
    const int64_t *a,
    const int64_t *b
) {
    __int128_t xu, xv, xw;

    int64_t tmp = opacify(a[2]);
    xw = widemuls(tmp, b[0]);
    tmp <<= 1;
    xu = widemuls(tmp, b[1]);
    xv = widemuls(tmp, b[2]);

    tmp = opacify(a[1]);
    xw += widemuls(tmp, b[1]);
    xv += widemuls(tmp, b[0]);
    tmp <<= 1;
    xu += widemuls(tmp, b[2]);

    tmp = opacify(a[0]);
    xu += widemuls(tmp, b[0]);
    xv += widemuls(tmp, b[1]);
    xw += widemuls(tmp, b[2]);

    uint64x3_t
        lo = { (uint64_t)(xu), (uint64_t)(xv), (uint64_t)(xw), 0 },
        hi = { (uint64_t)(xu>>64), (uint64_t)(xv>>64), (uint64_t)(xw>>64), 0 };

    /*
    hex->hier = (uint64x4_t)((int64x4_t)hi>>52);
    hex->hi = (hi<<12)>>6 | lo>>58;
    hex->lo = lo & mask58;
    */

    hex->hi = hi<<6 | lo>>58;
    hex->lo = lo & mask58;
}
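
/* Squaring specialisation of nonad_mul: the symmetric cross products are
 * computed once and doubled via the extra shifts of tmp.
 */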
static inline void nonad_sqr (
    nonad_t *hex,
    const uint64_t *a
) {
    __uint128_t xu, xv, xw;

    int64_t tmp = a[2];
    tmp <<= 1;
    xw = widemulu(tmp, a[0]);
    xv = widemulu(tmp, a[2]);
    tmp <<= 1;
    xu = widemulu(tmp, a[1]);

    tmp = a[1];
    xw += widemulu(tmp, a[1]);
    tmp <<= 1;
    xv += widemulu(tmp, a[0]);

    tmp = a[0];
    xu += widemulu(tmp, a[0]);

    uint64x3_t
        lo = { (uint64_t)(xu), (uint64_t)(xv), (uint64_t)(xw), 0 },
        hi = { (uint64_t)(xu>>64), (uint64_t)(xv>>64), (uint64_t)(xw>>64), 0 };

    hex->hier = hi>>52;
    hex->hi = (hi<<12)>>6 | lo>>58;
    hex->lo = lo & mask58;
}
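
/* Signed squaring counterpart of hexad_mul_signed (two-way lo/hi split). */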
static inline void hexad_sqr_signed (
    nonad_t *hex,
    const int64_t *a
) {
    __uint128_t xu, xv, xw;

    int64_t tmp = a[2];
    tmp <<= 1;
    xw = widemuls(tmp, a[0]);
    xv = widemuls(tmp, a[2]);
    tmp <<= 1;
    xu = widemuls(tmp, a[1]);

    tmp = a[1];
    xw += widemuls(tmp, a[1]);
    tmp <<= 1;
    xv += widemuls(tmp, a[0]);

    tmp = a[0];
    xu += widemuls(tmp, a[0]);

    uint64x3_t
        lo = { (uint64_t)(xu), (uint64_t)(xv), (uint64_t)(xw), 0 },
        hi = { (uint64_t)(xu>>64), (uint64_t)(xv>>64), (uint64_t)(xw>>64), 0 };

    /*
    hex->hier = (uint64x4_t)((int64x4_t)hi>>52);
    hex->hi = (hi<<12)>>6 | lo>>58;
    hex->lo = lo & mask58;
    */

    hex->hi = hi<<6 | lo>>58;
    hex->lo = lo & mask58;
}
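
/* Field multiplication mod p = 2^521 - 1.  The nine 58-bit limbs (the top one
 * effectively 57 bits; physical slots 3, 7 and 11 are zero padding) are
 * grouped into three interleaved 3-limb blocks.  The block products come from
 * nonad_mul, the cross terms are recovered Karatsuba-style from the signed
 * difference products, and timesW (from the header) presumably applies the
 * wrap-around weight of the representation.  The bias vectors vhi and vhi2
 * together add 26*p, which keeps the subtracted terms non-negative without
 * changing the result mod p.
 */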
void
p521_mul (
    p521_t *__restrict__ cs,
    const p521_t *as,
    const p521_t *bs
) {
    int i;

#if 0
    assert(as->limb[3] == 0 && as->limb[7] == 0 && as->limb[11] == 0);
    assert(bs->limb[3] == 0 && bs->limb[7] == 0 && bs->limb[11] == 0);
    for (i=0; i<12; i++) {
        assert(as->limb[i] < 5ull<<57);
        assert(bs->limb[i] < 5ull<<57);
    }
#endif

    /* Bounds on the hexads and nonads.
     *
     * Limbs < 2<<58 + ep.
     * Nonad mul < 1<<58, 1<<58, tiny
     *   -> t0 < (3,2,2)<<58 + tiny
     *      t1,t2 < 2<<58 + tiny
     *      * w < (4,2,2)
     * Hexad mul < +- (5,4,3) * 4<<116 -> 2^58 lo, +- (5,4,3) * 4<<58+ep
     * TimesW < (2,1,1)<<58, (6,5,4)*4<<58 + ep
     * ot0 = t0 + timesW(t2 + t1 - acdf.hi - bcef.lo);
     *     == (3,2,2) + (4,2,2) + (4,2,2) +- (6,5,4)*4 - (1) << 58
     *     in (-25, +35) << 58
     *
     * uint64x3_t ot0 = t0 + timesW(t2 + t1 - acdf.hi - bcef.lo);
     * uint64x3_t ot1 = t0 + t1 - abde.lo + timesW(t2 - bcef.hi);
     * uint64x3_t ot2 = t0 + t1 + t2 - abde.hi - acdf.lo + vhi2;
     */

    uint64_t *c = cs->limb;
    const uint64_t *a = as->limb, *b = bs->limb;

    nonad_t ad, be, cf, abde, bcef, acdf;
    nonad_mul(&ad, &a[0], &b[0]);
    nonad_mul(&be, &a[4], &b[4]);
    nonad_mul(&cf, &a[8], &b[8]);

    uint64_t amt = 26;
    uint64x3_t vhi = { amt*((1ull<<58)-1), amt*((1ull<<58)-1), amt*((1ull<<58)-1), 0 },
        vhi2 = { 0, 0, -amt<<57, 0 };

    uint64x3_t t2 = cf.lo + be.hi + ad.hier,
               t0 = ad.lo + timesW(cf.hi + be.hier) + vhi,
               t1 = ad.hi + be.lo + timesW(cf.hier);

    int64_t ta[4] VECTOR_ALIGNED, tb[4] VECTOR_ALIGNED;
    // it seems to be faster not to vectorize these loops
    for (i=0; i<3; i++) {
        ta[i] = a[i]-a[i+4];
        tb[i] = b[i]-b[i+4];
    }
    hexad_mul_signed(&abde,ta,tb);

    for (i=0; i<3; i++) {
        ta[i] = a[i+4]-a[i+8];
        tb[i] = b[i+4]-b[i+8];
    }
    hexad_mul_signed(&bcef,ta,tb);

    for (i=0; i<3; i++) {
        ta[i] = a[i]-a[i+8];
        tb[i] = b[i]-b[i+8];
    }
    hexad_mul_signed(&acdf,ta,tb);

    uint64x3_t ot0 = t0 + timesW(t2 + t1 - acdf.hi - bcef.lo);
    uint64x3_t ot1 = t0 + t1 - abde.lo + timesW(t2 - bcef.hi);
    uint64x3_t ot2 = t0 + t1 + t2 - abde.hi - acdf.lo + vhi2;

    uint64x3_t out0 = (ot0 & mask58) + timesW(ot2>>58);
    uint64x3_t out1 = (ot1 & mask58) + (ot0>>58);
    uint64x3_t out2 = (ot2 & mask58) + (ot1>>58);

    *(uint64x4_t *)&c[0] = out0;
    *(uint64x4_t *)&c[4] = out1;
    *(uint64x4_t *)&c[8] = out2;
}
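
/* Field squaring; same recombination as p521_mul, using the squaring
 * specialisations of the block products.
 */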
void
p521_sqr (
    p521_t *__restrict__ cs,
    const p521_t *as
) {
    int i;

#if 0
    assert(as->limb[3] == 0 && as->limb[7] == 0 && as->limb[11] == 0);
    for (i=0; i<12; i++) {
        assert(as->limb[i] < 5ull<<57);
    }
#endif

    uint64_t *c = cs->limb;
    const uint64_t *a = as->limb;

    nonad_t ad, be, cf, abde, bcef, acdf;
    nonad_sqr(&ad, &a[0]);
    nonad_sqr(&be, &a[4]);
    nonad_sqr(&cf, &a[8]);

    uint64_t amt = 26;
    uint64x3_t vhi = { amt*((1ull<<58)-1), amt*((1ull<<58)-1), amt*((1ull<<58)-1), 0 },
        vhi2 = { 0, 0, -amt<<57, 0 };

    uint64x3_t t2 = cf.lo + be.hi + ad.hier,
               t0 = ad.lo + timesW(cf.hi + be.hier) + vhi,
               t1 = ad.hi + be.lo + timesW(cf.hier);

    int64_t ta[4] VECTOR_ALIGNED;
    // it seems to be faster not to vectorize these loops
    for (i=0; i<3; i++) {
        ta[i] = a[i]-a[i+4];
    }
    hexad_sqr_signed(&abde,ta);

    for (i=0; i<3; i++) {
        ta[i] = a[i+4]-a[i+8];
    }
    hexad_sqr_signed(&bcef,ta);

    for (i=0; i<3; i++) {
        ta[i] = a[i]-a[i+8];
    }
    hexad_sqr_signed(&acdf,ta);

    uint64x3_t ot0 = t0 + timesW(t2 + t1 - acdf.hi - bcef.lo);
    uint64x3_t ot1 = t0 + t1 - abde.lo + timesW(t2 - bcef.hi);
    uint64x3_t ot2 = t0 + t1 + t2 - abde.hi - acdf.lo + vhi2;

    uint64x3_t out0 = (ot0 & mask58) + timesW(ot2>>58);
    uint64x3_t out1 = (ot1 & mask58) + (ot0>>58);
    uint64x3_t out2 = (ot2 & mask58) + (ot1>>58);

    *(uint64x4_t *)&c[0] = out0;
    *(uint64x4_t *)&c[4] = out1;
    *(uint64x4_t *)&c[8] = out2;
}
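
/* Multiply by a single 64-bit word b.  Runs three parallel carry chains over
 * the physical slots {0,4,8}, {1,5,9} and {2,6,10}, then resolves the
 * remaining cross-chain carries; the carry off the top limb wraps back to
 * limb 0 (2^521 == 1 mod p).
 */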
void
p521_mulw (
    p521_t *__restrict__ cs,
    const p521_t *as,
    uint64_t b
) {
#if 0
    int i;
    assert(as->limb[3] == 0 && as->limb[7] == 0 && as->limb[11] == 0);
    for (i=0; i<12; i++) {
        assert(as->limb[i] < 1ull<<61);
    }
    assert(b < 1ull<<61);
#endif

    const uint64_t *a = as->limb;
    uint64_t *c = cs->limb;

    __uint128_t accum0 = 0, accum3 = 0, accum6 = 0;
    uint64_t mask = (1ull<<58) - 1;

    accum0 += widemulu(b, a[0]);
    accum3 += widemulu(b, a[1]);
    accum6 += widemulu(b, a[2]);
    c[0] = accum0 & mask; accum0 >>= 58;
    c[1] = accum3 & mask; accum3 >>= 58;
    c[2] = accum6 & mask; accum6 >>= 58;

    accum0 += widemulu(b, a[4]);
    accum3 += widemulu(b, a[5]);
    accum6 += widemulu(b, a[6]);
    c[4] = accum0 & mask; accum0 >>= 58;
    c[5] = accum3 & mask; accum3 >>= 58;
    c[6] = accum6 & mask; accum6 >>= 58;

    accum0 += widemulu(b, a[8]);
    accum3 += widemulu(b, a[9]);
    accum6 += widemulu(b, a[10]);
    c[8] = accum0 & mask; accum0 >>= 58;
    c[9] = accum3 & mask; accum3 >>= 58;
    c[10] = accum6 & (mask>>1); accum6 >>= 57;

    accum0 += c[1];
    c[1] = accum0 & mask;
    c[5] += accum0 >> 58;

    accum3 += c[2];
    c[2] = accum3 & mask;
    c[6] += accum3 >> 58;

    accum6 += c[0];
    c[0] = accum6 & mask;
    c[4] += accum6 >> 58;

    c[3] = c[7] = c[11] = 0;
}
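
/* Reduce a to its canonical representative in [0, p), without branches:
 * fold the bits at and above 2^521 back to the bottom (2^521 == 1 mod p),
 * subtract p, and add p back if the subtraction borrowed.
 */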
void
p521_strong_reduce (
    p521_t *a
) {
    uint64_t mask = (1ull<<58)-1, mask2 = (1ull<<57)-1;

    /* first, clear high */
    __int128_t scarry = a->limb[LIMBPERM(8)]>>57;
    a->limb[LIMBPERM(8)] &= mask2;

    /* now the total is less than 2p */

    /* compute total_value - p.  No need to reduce mod p. */
    int i;
    for (i=0; i<9; i++) {
        scarry = scarry + a->limb[LIMBPERM(i)] - ((i==8) ? mask2 : mask);
        a->limb[LIMBPERM(i)] = scarry & ((i==8) ? mask2 : mask);
        scarry >>= (i==8) ? 57 : 58;
    }

    /* uncommon case: it was >= p, so now scarry = 0 and this = x
     * common case: it was < p, so now scarry = -1 and this = x - p + 2^521
     * so let's add back in p.  will carry back off the top for 2^521.
     */
    assert(is_zero(scarry) | is_zero(scarry+1));

    uint64_t scarry_mask = scarry & mask;
    __uint128_t carry = 0;

    /* add it back */
    for (i=0; i<9; i++) {
        carry = carry + a->limb[LIMBPERM(i)] + ((i==8)?(scarry_mask>>1):scarry_mask);
        a->limb[LIMBPERM(i)] = carry & ((i==8) ? mask>>1 : mask);
        carry >>= (i==8) ? 57 : 58;
    }

    assert(is_zero(carry + scarry));

    a->limb[3] = a->limb[7] = a->limb[11] = 0;
}
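
/* Zero test: strongly reduce a copy, then OR all limbs together. */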
mask_t
p521_is_zero (
    const struct p521_t *a
) {
    struct p521_t b;
    p521_copy(&b,a);
    p521_strong_reduce(&b);

    uint64_t any = 0;
    unsigned int i;
    for (i=0; i<sizeof(b)/sizeof(b.limb[0]); i++) {
        any |= b.limb[i];
    }
    return is_zero(any);
}
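
/* Serialize to 66 little-endian bytes, after strongly reducing a copy. */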
void
p521_serialize (
    uint8_t *serial,
    const struct p521_t *x
) {
    unsigned int i,k=0;
    p521_t red;
    p521_copy(&red, x);
    p521_strong_reduce(&red);

    uint64_t r=0;
    int bits = 0;
    for (i=0; i<9; i++) {
        r |= red.limb[LIMBPERM(i)] << bits;
        for (bits += 58; bits >= 8; bits -= 8) {
            serial[k++] = r;
            r >>= 8;
        }
        assert(bits <= 6);
    }
    assert(bits);
    serial[k++] = r;
}
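
/* Deserialize 66 little-endian bytes.  Returns an all-ones mask if the encoded
 * value is canonical (strictly less than p), and zero otherwise.
 */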
mask_t
p521_deserialize (
    p521_t *x,
    const uint8_t serial[LIMBPERM(66)]
) {
    int i,k=0,bits=0;
    __uint128_t out = 0;
    uint64_t mask = (1ull<<58)-1;

    for (i=0; i<9; i++) {
        out >>= 58;
        for (; bits<58; bits+=8) {
            out |= ((__uint128_t)serial[k++])<<bits;
        }
        x->limb[LIMBPERM(i)] = out & mask;
        bits -= 58;
    }

    /* Check for reduction.  First, high has to be < 2^57 */
    mask_t good = is_zero(out>>57);

    uint64_t and = -1ull;
    for (i=0; i<8; i++) {
        and &= x->limb[LIMBPERM(i)];
    }
    and &= (2*out+1);
    good &= is_zero((and+1)>>58);

    x->limb[3] = x->limb[7] = x->limb[11] = 0;

    return good;
}