/* Copyright (c) 2014 Cryptography Research, Inc.
 * Released under the MIT License. See LICENSE.txt for license information.
 */

#include "f_field.h"

typedef struct {
    uint64x3_t lo, hi, hier;
} nonad_t;

static inline __uint128_t widemulu(uint64_t a, uint64_t b) {
    return ((__uint128_t)(a)) * b;
}

static inline __int128_t widemuls(int64_t a, int64_t b) {
    return ((__int128_t)(a)) * b;
}

/* This is a trick to prevent terrible register allocation by hiding things from clang's optimizer */
static inline uint64_t opacify(uint64_t x) {
    __asm__ volatile("" : "+r"(x));
    return x;
}

/* These used to be hexads, leading to 10% better performance, but there were overflow issues */
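/* Descriptive note: nonad_mul computes the 3x3 column product of the limb
 * triples a[0..2] and b[0..2].  Cross terms that overflow the triple (index
 * sums 3 and 4) are folded back with a factor of 2 via the doubled tmp values.
 * Each 128-bit column is then split into a 58-bit low word (lo), a 58-bit
 * middle word (hi) and a 12-bit top carry (hier): nine output words in all,
 * presumably the "nonad" of the name.
 */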
static inline void nonad_mul (
    nonad_t *hex,
    const uint64_t *a,
    const uint64_t *b
) {
    __uint128_t xu, xv, xw;

    uint64_t tmp = opacify(a[2]);
    xw = widemulu(tmp, b[0]);
    tmp <<= 1;
    xu = widemulu(tmp, b[1]);
    xv = widemulu(tmp, b[2]);

    tmp = opacify(a[1]);
    xw += widemulu(tmp, b[1]);
    xv += widemulu(tmp, b[0]);
    tmp <<= 1;
    xu += widemulu(tmp, b[2]);

    tmp = opacify(a[0]);
    xu += widemulu(tmp, b[0]);
    xv += widemulu(tmp, b[1]);
    xw += widemulu(tmp, b[2]);

    uint64x3_t
        lo = { (uint64_t)(xu), (uint64_t)(xv), (uint64_t)(xw), 0 },
        hi = { (uint64_t)(xu>>64), (uint64_t)(xv>>64), (uint64_t)(xw>>64), 0 };

    hex->hier = hi>>52;
    hex->hi = (hi<<12)>>6 | lo>>58;
    hex->lo = lo & mask58;
}
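/* Descriptive note: hexad_mul_signed is the same 3x3 column product, but for
 * the signed difference triples used by the Karatsuba recombination in
 * gf_mul/gf_sqr.  Only lo and hi are produced (hier is left unset), i.e. six
 * words rather than nine.
 */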
static inline void hexad_mul_signed (
    nonad_t *hex,
    const int64_t *a,
    const int64_t *b
) {
    __int128_t xu, xv, xw;

    int64_t tmp = opacify(a[2]);
    xw = widemuls(tmp, b[0]);
    tmp <<= 1;
    xu = widemuls(tmp, b[1]);
    xv = widemuls(tmp, b[2]);

    tmp = opacify(a[1]);
    xw += widemuls(tmp, b[1]);
    xv += widemuls(tmp, b[0]);
    tmp <<= 1;
    xu += widemuls(tmp, b[2]);

    tmp = opacify(a[0]);
    xu += widemuls(tmp, b[0]);
    xv += widemuls(tmp, b[1]);
    xw += widemuls(tmp, b[2]);

    uint64x3_t
        lo = { (uint64_t)(xu), (uint64_t)(xv), (uint64_t)(xw), 0 },
        hi = { (uint64_t)(xu>>64), (uint64_t)(xv>>64), (uint64_t)(xw>>64), 0 };

    /*
    hex->hier = (uint64x4_t)((int64x4_t)hi>>52);
    hex->hi = (hi<<12)>>6 | lo>>58;
    hex->lo = lo & mask58;
    */

    hex->hi = hi<<6 | lo>>58;
    hex->lo = lo & mask58;
}
static inline void nonad_sqr (
    nonad_t *hex,
    const uint64_t *a
) {
    __uint128_t xu, xv, xw;

    int64_t tmp = a[2];
    tmp <<= 1;
    xw = widemulu(tmp, a[0]);
    xv = widemulu(tmp, a[2]);
    tmp <<= 1;
    xu = widemulu(tmp, a[1]);

    tmp = a[1];
    xw += widemulu(tmp, a[1]);
    tmp <<= 1;
    xv += widemulu(tmp, a[0]);

    tmp = a[0];
    xu += widemulu(tmp, a[0]);

    uint64x3_t
        lo = { (uint64_t)(xu), (uint64_t)(xv), (uint64_t)(xw), 0 },
        hi = { (uint64_t)(xu>>64), (uint64_t)(xv>>64), (uint64_t)(xw>>64), 0 };

    hex->hier = hi>>52;
    hex->hi = (hi<<12)>>6 | lo>>58;
    hex->lo = lo & mask58;
}
static inline void hexad_sqr_signed (
    nonad_t *hex,
    const int64_t *a
) {
    __uint128_t xu, xv, xw;

    int64_t tmp = a[2];
    tmp <<= 1;
    xw = widemuls(tmp, a[0]);
    xv = widemuls(tmp, a[2]);
    tmp <<= 1;
    xu = widemuls(tmp, a[1]);

    tmp = a[1];
    xw += widemuls(tmp, a[1]);
    tmp <<= 1;
    xv += widemuls(tmp, a[0]);

    tmp = a[0];
    xu += widemuls(tmp, a[0]);

    uint64x3_t
        lo = { (uint64_t)(xu), (uint64_t)(xv), (uint64_t)(xw), 0 },
        hi = { (uint64_t)(xu>>64), (uint64_t)(xv>>64), (uint64_t)(xw>>64), 0 };

    /*
    hex->hier = (uint64x4_t)((int64x4_t)hi>>52);
    hex->hi = (hi<<12)>>6 | lo>>58;
    hex->lo = lo & mask58;
    */

    hex->hi = hi<<6 | lo>>58;
    hex->lo = lo & mask58;
}
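/* Descriptive note: gf_mul multiplies mod p = 2^521 - 1.  Each operand keeps
 * its 9 limbs in slots 0-2, 4-6 and 8-10 of a 12-slot vector layout (slots 3,
 * 7, 11 are zero padding).  The three triple products ad, be, cf plus the
 * three signed difference products abde, bcef, acdf form a 3-way Karatsuba.
 * vhi adds amt*(2^58-1) to every lane and vhi2 corrects the top lane so the
 * net bias appears to be a multiple of p; per the bounds comment below the
 * signed terms stay above -25<<58, so amt = 26 keeps every lane nonnegative.
 */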
void gf_mul (gf *__restrict__ cs, const gf *as, const gf *bs) {
    int i;
#if 0
    assert(as->limb[3] == 0 && as->limb[7] == 0 && as->limb[11] == 0);
    assert(bs->limb[3] == 0 && bs->limb[7] == 0 && bs->limb[11] == 0);
    for (i=0; i<12; i++) {
        assert(as->limb[i] < 5ull<<57);
        assert(bs->limb[i] < 5ull<<57);
    }
#endif

    /* Bounds on the hexads and nonads.
     *
     * Limbs < 2<<58 + ep.
     * Nonad mul < 1<<58, 1<<58, tiny
     * -> t0 < (3,2,2)<<58 + tiny
     *    t1,t2 < 2<<58 + tiny
     *    * w < (4,2,2)
     * Hexad mul < +- (5,4,3) * 4<<116 -> 2^58 lo, +- (5,4,3) * 4<<58 + ep
     * TimesW < (2,1,1)<<58, (6,5,4)*4<<58 + ep
     * ot2 = t0 + timesW(t2 + t1 - acdf.hi - bcef.lo);
     *     == (3,2,2) + (4,2,2) + (4,2,2) +- (6,5,4)*4 - (1) << 58
     *     in (-25, +35) << 58
     *
     *   uint64x3_t ot0 = t0 + timesW(t2 + t1 - acdf.hi - bcef.lo);
     *   uint64x3_t ot1 = t0 + t1 - abde.lo + timesW(t2 - bcef.hi);
     *   uint64x3_t ot2 = t0 + t1 + t2 - abde.hi - acdf.lo + vhi2;
     */

    uint64_t *c = cs->limb;
    const uint64_t *a = as->limb, *b = bs->limb;
    nonad_t ad, be, cf, abde, bcef, acdf;

    nonad_mul(&ad, &a[0], &b[0]);
    nonad_mul(&be, &a[4], &b[4]);
    nonad_mul(&cf, &a[8], &b[8]);

    uint64_t amt = 26;
    uint64x3_t vhi = { amt*((1ull<<58)-1), amt*((1ull<<58)-1), amt*((1ull<<58)-1), 0 },
        vhi2 = { 0, 0, -amt<<57, 0 };

    uint64x3_t t2 = cf.lo + be.hi + ad.hier, t0 = ad.lo + timesW(cf.hi + be.hier) + vhi, t1 = ad.hi + be.lo + timesW(cf.hier);

    int64_t ta[4] VECTOR_ALIGNED, tb[4] VECTOR_ALIGNED;

    // it seems to be faster not to vectorize these loops
    for (i=0; i<3; i++) {
        ta[i] = a[i]-a[i+4];
        tb[i] = b[i]-b[i+4];
    }
    hexad_mul_signed(&abde,ta,tb);

    for (i=0; i<3; i++) {
        ta[i] = a[i+4]-a[i+8];
        tb[i] = b[i+4]-b[i+8];
    }
    hexad_mul_signed(&bcef,ta,tb);

    for (i=0; i<3; i++) {
        ta[i] = a[i]-a[i+8];
        tb[i] = b[i]-b[i+8];
    }
    hexad_mul_signed(&acdf,ta,tb);

    uint64x3_t ot0 = t0 + timesW(t2 + t1 - acdf.hi - bcef.lo);
    uint64x3_t ot1 = t0 + t1 - abde.lo + timesW(t2 - bcef.hi);
    uint64x3_t ot2 = t0 + t1 + t2 - abde.hi - acdf.lo + vhi2;

    uint64x3_t out0 = (ot0 & mask58) + timesW(ot2>>58);
    uint64x3_t out1 = (ot1 & mask58) + (ot0>>58);
    uint64x3_t out2 = (ot2 & mask58) + (ot1>>58);

    *(uint64x4_t *)&c[0] = out0;
    *(uint64x4_t *)&c[4] = out1;
    *(uint64x4_t *)&c[8] = out2;
}
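/* Descriptive note: gf_sqr squares with the same 3-way Karatsuba structure as
 * gf_mul, using the squaring-specialized nonad_sqr/hexad_sqr_signed
 * primitives and a single difference array.
 */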
void gf_sqr (gf *__restrict__ cs, const gf *as) {
    int i;
#if 0
    assert(as->limb[3] == 0 && as->limb[7] == 0 && as->limb[11] == 0);
    for (i=0; i<12; i++) {
        assert(as->limb[i] < 5ull<<57);
    }
#endif

    uint64_t *c = cs->limb;
    const uint64_t *a = as->limb;
    nonad_t ad, be, cf, abde, bcef, acdf;

    nonad_sqr(&ad, &a[0]);
    nonad_sqr(&be, &a[4]);
    nonad_sqr(&cf, &a[8]);

    uint64_t amt = 26;
    uint64x3_t vhi = { amt*((1ull<<58)-1), amt*((1ull<<58)-1), amt*((1ull<<58)-1), 0 },
        vhi2 = { 0, 0, -amt<<57, 0 };

    uint64x3_t t2 = cf.lo + be.hi + ad.hier, t0 = ad.lo + timesW(cf.hi + be.hier) + vhi, t1 = ad.hi + be.lo + timesW(cf.hier);

    int64_t ta[4] VECTOR_ALIGNED;

    // it seems to be faster not to vectorize these loops
    for (i=0; i<3; i++) {
        ta[i] = a[i]-a[i+4];
    }
    hexad_sqr_signed(&abde,ta);

    for (i=0; i<3; i++) {
        ta[i] = a[i+4]-a[i+8];
    }
    hexad_sqr_signed(&bcef,ta);

    for (i=0; i<3; i++) {
        ta[i] = a[i]-a[i+8];
    }
    hexad_sqr_signed(&acdf,ta);

    uint64x3_t ot0 = t0 + timesW(t2 + t1 - acdf.hi - bcef.lo);
    uint64x3_t ot1 = t0 + t1 - abde.lo + timesW(t2 - bcef.hi);
    uint64x3_t ot2 = t0 + t1 + t2 - abde.hi - acdf.lo + vhi2;

    uint64x3_t out0 = (ot0 & mask58) + timesW(ot2>>58);
    uint64x3_t out1 = (ot1 & mask58) + (ot0>>58);
    uint64x3_t out2 = (ot2 & mask58) + (ot1>>58);

    *(uint64x4_t *)&c[0] = out0;
    *(uint64x4_t *)&c[4] = out1;
    *(uint64x4_t *)&c[8] = out2;
}
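/* Descriptive note: gf_mulw multiplies a field element by the 64-bit scalar
 * b.  The three limb columns (slots {0,4,8}, {1,5,9}, {2,6,10}) are
 * accumulated independently in 128-bit accumulators; the final carries wrap
 * around modulo 2^521 - 1, with the top column shifted by 57 rather than 58
 * bits.
 */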
void gf_mulw (gf *__restrict__ cs, const gf *as, uint64_t b) {
#if 0
    int i;
    assert(as->limb[3] == 0 && as->limb[7] == 0 && as->limb[11] == 0);
    for (i=0; i<12; i++) {
        assert(as->limb[i] < 1ull<<61);
    }
    assert(b < 1ull<<61);
#endif

    const uint64_t *a = as->limb;
    uint64_t *c = cs->limb;

    __uint128_t accum0 = 0, accum3 = 0, accum6 = 0;
    uint64_t mask = (1ull<<58) - 1;

    accum0 += widemulu(b, a[0]);
    accum3 += widemulu(b, a[1]);
    accum6 += widemulu(b, a[2]);

    c[0] = accum0 & mask; accum0 >>= 58;
    c[1] = accum3 & mask; accum3 >>= 58;
    c[2] = accum6 & mask; accum6 >>= 58;

    accum0 += widemulu(b, a[4]);
    accum3 += widemulu(b, a[5]);
    accum6 += widemulu(b, a[6]);

    c[4] = accum0 & mask; accum0 >>= 58;
    c[5] = accum3 & mask; accum3 >>= 58;
    c[6] = accum6 & mask; accum6 >>= 58;

    accum0 += widemulu(b, a[8]);
    accum3 += widemulu(b, a[9]);
    accum6 += widemulu(b, a[10]);

    c[8] = accum0 & mask; accum0 >>= 58;
    c[9] = accum3 & mask; accum3 >>= 58;
    c[10] = accum6 & (mask>>1); accum6 >>= 57;

    accum0 += c[1];
    c[1] = accum0 & mask;
    c[5] += accum0 >> 58;

    accum3 += c[2];
    c[2] = accum3 & mask;
    c[6] += accum3 >> 58;

    accum6 += c[0];
    c[0] = accum6 & mask;
    c[4] += accum6 >> 58;

    c[3] = c[7] = c[11] = 0;
}
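/* Descriptive note: gf_strong_reduce reduces a to its canonical
 * representative in [0, p), p = 2^521 - 1: it subtracts p, then adds it back
 * masked by the resulting borrow instead of branching on it.
 */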
void gf_strong_reduce (gf *a) {
    uint64_t mask = (1ull<<58)-1, mask2 = (1ull<<57)-1;

    /* first, clear high */
    __int128_t scarry = a->limb[LIMBPERM(8)]>>57;
    a->limb[LIMBPERM(8)] &= mask2;

    /* now the total is less than 2p */

    /* compute total_value - p.  No need to reduce mod p. */
    int i;
    for (i=0; i<9; i++) {
        scarry = scarry + a->limb[LIMBPERM(i)] - ((i==8) ? mask2 : mask);
        a->limb[LIMBPERM(i)] = scarry & ((i==8) ? mask2 : mask);
        scarry >>= (i==8) ? 57 : 58;
    }

    /* uncommon case: it was >= p, so now scarry = 0 and this = x
     * common case: it was < p, so now scarry = -1 and this = x - p + 2^521
     * so let's add back in p.  will carry back off the top for 2^521.
     */
    assert(word_is_zero(scarry) | word_is_zero(scarry+1));

    uint64_t scarry_mask = scarry & mask;
    __uint128_t carry = 0;

    /* add it back */
    for (i=0; i<9; i++) {
        carry = carry + a->limb[LIMBPERM(i)] + ((i==8)?(scarry_mask>>1):scarry_mask);
        a->limb[LIMBPERM(i)] = carry & ((i==8) ? mask>>1 : mask);
        carry >>= (i==8) ? 57 : 58;
    }

    assert(word_is_zero(carry + scarry));

    a->limb[3] = a->limb[7] = a->limb[11] = 0;
}