1 module crypto.bigint;
2 
3 import std.bigint;
4 import std.algorithm.mutation : reverse, swap;
5 import std.algorithm.searching : find, all;
6 import std.conv : to, text;
7 import std.exception : enforce;
8 import std.range : repeat, array;
9 import std.math : abs;
10 
11 import crypto.random;
12 
struct BigIntHelper
{
    /**
    Generates a random non-negative BigInt from `bitLength` random bits.

    Params:
        bitLength = number of random bits; must be a positive multiple of 8.
        highBit   = 0 clears and 1 sets the most significant bit; any other value leaves it random.
        lowBit    = 0 clears and 1 sets the least significant bit; any other value leaves it random.

    Returns: a value in [0, 2^^bitLength).
    Throws: `Exception` (from `enforce`) if `bitLength` is 0 or not a multiple of 8.
    */
    static BigInt randomGenerate(uint bitLength, int highBit = -1, int lowBit = -1)
    {
        enforce((bitLength > 0) && (bitLength % 8 == 0));

        ubyte[] buffer = new ubyte[bitLength / 8];

        // Each value drawn from `rnd.next` is sliced into uint.sizeof bytes, least significant byte first.
        uint pos = 0;
        uint current = 0;
        foreach (ref a; buffer)
        {
            if (pos == 0)
            {
                current = rnd.next;
            }

            a = cast(ubyte)(current >> (8 * pos));
            pos = (pos + 1) % uint.sizeof;
        }

        // fromBytes reads the buffer big-endian: buffer[0] holds the top bits, buffer[$ - 1] the bottom bits.
        if (highBit == 0)
        {
            buffer[0] &= 0x7F;
        }
        else if (highBit == 1)
        {
            buffer[0] |= 0x80;
        }

        if (lowBit == 0)
        {
            buffer[$ - 1] &= 0xFE;
        }
        else if (lowBit == 1)
        {
            buffer[$ - 1] |= 0x01;
        }

        return BigIntHelper.fromBytes(buffer);
    }

    /**
    Generates a random BigInt in the closed interval [min, max].

    A random value one 32-bit digit wider than `max` is reduced modulo the interval size.

    Throws: `Exception` if `max < min`.
    */
    static BigInt randomGenerate(const BigInt min, const BigInt max)
    {
        enforce(max >= min, text("BigIntHelper.randomGenerate(): invalid bounding interval ", min, ", ", max));

        BigInt r = randomGenerate(cast(uint)((max.uintLength + 1) * uint.sizeof * 8));
        return r % (max - min + 1) + min;
    }

    /**
    Serializes the magnitude of `value` as big-endian bytes with leading zero bytes removed.

    The sign is not encoded, and zero yields an empty array.
    Digits are split with shifts, so the result does not depend on host endianness.
    */
    static ubyte[] toUBytes(const BigInt value) pure nothrow
    {
        immutable size_t len = value.uintLength();
        ubyte[] ubytes = new ubyte[len * uint.sizeof];

        foreach (i; 0 .. len)
        {
            immutable uint digit = value.getDigit!uint(i);
            // Digit i (least significant first) fills the i-th uint-sized group counted from the end.
            immutable size_t offset = (len - i - 1) * uint.sizeof;

            foreach (j; 0 .. uint.sizeof)
            {
                ubytes[offset + uint.sizeof - j - 1] = cast(ubyte)(digit >> (8 * j));
            }
        }

        // Strip leading zero bytes.
        return ubytes.find!((a, b) => a != b)(0);
    }

    /++
        Interprets `buffer` as a big-endian unsigned integer. An empty buffer yields 0.

        Because std.bigint's member `data` is private and there is no `setDigit`
        counterpart to `getDigit`, the value is built by shifting in one 32-bit digit at a time.
        !! Here is a performance bottleneck.
    +/
    static BigInt fromBytes(in ubyte[] buffer) pure nothrow
    {
        // Left-pad with zero bytes so the length is a whole number of 32-bit digits.
        immutable size_t padding = (uint.sizeof - buffer.length % uint.sizeof) % uint.sizeof;
        const(ubyte)[] bytes = buffer;
        if (padding > 0)
        {
            bytes = (cast(ubyte) 0).repeat(padding).array ~ buffer;
        }

        BigInt data = 0;

        for (size_t i = 0; i < bytes.length; i += uint.sizeof)
        {
            // Assemble the digit arithmetically (big-endian), independent of host byte order.
            uint digit = 0;
            foreach (b; bytes[i .. i + uint.sizeof])
            {
                digit = (digit << 8) | b;
            }

            data <<= 32;
            data += digit;
        }

        return data;
    }

    static if (__VERSION__ >= 2087)
        alias powmod = std.bigint.powmod;
    else
    {
        /// Computes base ^^ exponent % modulus by recursive squaring (fallback for compilers older than 2.087).
        static BigInt powmod(const BigInt base, const BigInt exponent, const BigInt modulus) pure nothrow
        {
            assert(base >= 0 && exponent >= 0 && modulus >= 1);

            if (exponent == 0)
            {
                return BigInt(1) % modulus;
            }

            if (exponent == 1)
            {
                return base % modulus;
            }

            BigInt temp = powmod(base, exponent / 2, modulus);

            return (exponent & 1) ? mul(mul(temp, temp), base) % modulus : mul(temp, temp) % modulus;
        }
    }

    /**
    Tests whether BigInt n is (probably) prime.
        Step 1: Miller-Rabin strong probable-prime test (deterministic base sets below 10^16).
        Step 2: for n >= 10^16, a Lucas probable-prime test (`lucasLehmerTest`).

    Params:
        n          = the candidate.
        confidence = number of random Miller-Rabin bases used when n >= 10^16; must be > 0.

    Returns: `false` if n is composite or < 2, `true` if it passed every test.
    Throws: `Exception` if `confidence` is 0.
    */
    static bool isProbablePrime(const BigInt n, const size_t confidence)
    {
        immutable bool passed = millerRabinPrimeTest(n, confidence);

        // Below 10^16 the Miller-Rabin base sets are deterministic, so that verdict is final.
        if (!passed || (n < 10_000_000_000_000_000))
        {
            return passed;
        }

        return lucasLehmerTest(n);
    }

private:

    /++
    Workaround for a BigInt multiplication bug in phobos, fixed in version 2.087.0.
        Details: https://github.com/dlang/phobos/pull/6972
    +/
    static if (__VERSION__ < 2087)
    {
        /// Schoolbook multiplication on 32-bit digits (magnitudes only).
        static BigInt mul(const BigInt a, const BigInt b) pure nothrow
        {
            uint[] au = toUintArray(a);
            uint[] bu = toUintArray(b);

            uint[] r = new uint[au.length + bu.length];

            for (size_t i = 0; i < bu.length; i++)
            {
                for (size_t j = 0; j < au.length; j++)
                {
                    ulong t = cast(ulong)bu[i] * au[j] + r[i + j];
                    r[i + j] = t & 0xFFFF_FFFF;
                    uint c = t >> 32;
                    size_t h = i + j + 1;

                    // Propagate the carry into the higher digits.
                    while (c != 0)
                    {
                        t = cast(ulong)c + r[h];
                        r[h] = t & 0xFFFF_FFFF;
                        c = t >> 32;
                        h++;
                    }
                }
            }

            return fromUintArray(r);
        }

        /// Digits of `data` as 32-bit words, least significant first.
        static uint[] toUintArray(const BigInt data) pure nothrow
        {
            size_t n = data.uintLength();
            uint[] arr = new uint[n];

            for (size_t i = 0; i < n; i++)
            {
                arr[i] = data.getDigit!uint(i);
            }

            return arr;
        }

        /// Inverse of toUintArray; high zero words are skipped.
        static BigInt fromUintArray(const uint[] arr) pure nothrow
        {
            size_t zeros = 0;
            foreach_reverse (d; arr)
            {
                if (d != 0)
                {
                    break;
                }

                zeros++;
            }

            BigInt data = 0;

            foreach_reverse (d; arr[0..$ - zeros])
            {
                data <<= 32;
                data += d;
            }

            return data;
        }
    }

    /**
    Miller-Rabin (strong probable-prime) test.

    For n < 10^16, fixed base sets are used, one set per range of n.
    For larger n, every prime in `smallPrimesTable` plus `confidence` random bases
    drawn from [n / 2, n - 1] must pass.

    Throws: `Exception` if `confidence` is 0.
    */
    static bool millerRabinPrimeTest(const BigInt n, const size_t confidence)
    {
        enforce(confidence > 0, "confidence must be a positive integer greater than 0.");

        if (n < 2)
        {
            return false;
        }
        if (n == 2)
        {
            return true;
        }
        if (n % 2 == 0)
        {
            return false;
        }

        // Write n - 1 = oddPart * 2^^twoPower with oddPart odd.
        BigInt oddPart = n - 1;
        size_t twoPower = 0;
        while (oddPart % 2 == 0)
        {
            oddPart >>= 1;
            ++twoPower;
        }

        bool passes(const BigInt base)
        {
            return isStrongProbablePrime(n, base, oddPart, twoPower);
        }

        // Each upper bound is exclusive: the bound itself is the first strong pseudoprime for that base set.
        BigInt[] bases;
        if (n < 1_373_653)
        {
            bases = [BigInt(2), BigInt(3)];
        }
        else if (n < 9_080_191)
        {
            bases = [BigInt(31), BigInt(73)];
        }
        else if (n < 4_759_123_141)
        {
            bases = [BigInt(2), BigInt(7), BigInt(61)];
        }
        else if (n < 2_152_302_898_747)
        {
            bases = [BigInt(2), BigInt(3), BigInt(5), BigInt(7), BigInt(11)];
        }
        else if (n < 341_550_071_728_321)
        {
            // Known composite, rejected explicitly (kept from the original implementation).
            if (n == 46_856_248_255_981)
            {
                return false;
            }

            bases = [BigInt(2), BigInt(3), BigInt(5), BigInt(7), BigInt(11), BigInt(13), BigInt(17)];
        }
        else if (n < 10_000_000_000_000_000)
        {
            bases = [BigInt(2), BigInt(3), BigInt(7), BigInt(61), BigInt(24251)];
        }
        else
        {
            foreach (prime; smallPrimesTable)
            {
                if (!passes(prime))
                {
                    return false;
                }
            }

            /**
            Although in theory base should be between 2 and n - 1, because confidence is optimized before call,
            the larger n is, the smaller confidence is, so the requirement for base can not be too small,
            so the minimum value does not use 2, but uses n / 2 instead.
            */
            foreach (_; 0 .. confidence)
            {
                if (!passes(randomGenerate(n / 2, n - 1)))
                {
                    return false;
                }
            }

            return true;
        }

        foreach (base; bases)
        {
            if (!passes(base))
            {
                return false;
            }
        }

        return true;
    }

    /**
    Single-base strong probable-prime check.

    Requires n odd and > 2, with n - 1 == oddPart * 2^^twoPower and oddPart odd.
    A base that is a multiple of n carries no information and is treated as passing.
    */
    static bool isStrongProbablePrime(const BigInt n, const BigInt base, const BigInt oddPart, const size_t twoPower)
    {
        BigInt a = base % n;
        if (a == 0)
        {
            return true;
        }

        BigInt nMinusOne = n - 1;
        BigInt x = powmod(a, oddPart, n);
        if (x == 1 || x == nMinusOne)
        {
            return true;
        }

        foreach (_; 1 .. twoPower)
        {
            // 0 stays 0 and 1 stays 1 under squaring, so n - 1 can no longer be reached.
            // This also keeps 0 away from the fallback powmod's base assertion.
            if (x <= 1)
            {
                return false;
            }

            // Square through powmod so old compilers keep using the `mul` workaround.
            x = powmod(x, BigInt(2), n);
            if (x == nMinusOne)
            {
                return true;
            }
        }

        return false;
    }

    /**
    Returns true if n is a Lucas probable prime.
        Selects d from 5, -7, 9, -11, ... with Jacobi(d / n) == -1, then checks that U_(n+1) == 0 (mod n).
        Intended to be called after millerRabinPrimeTest has passed.
        Even n, perfect squares (no suitable d exists) and n sharing a factor with d are handled up front.
    */
    static bool lucasLehmerTest(const BigInt n)
    {
        if (n < 2)
        {
            return false;
        }
        if (n % 2 == 0)
        {
            return n == 2;
        }

        // For a perfect square, Jacobi(d / n) is never -1, so the search below would not terminate.
        if (isPerfectSquare(n))
        {
            return false;
        }

        immutable BigInt nPlusOne = n + 1;

        int d = 5;
        for (;;)
        {
            immutable int symbol = jacobiSymbol(d, n);
            if (symbol == -1)
            {
                break;
            }

            // Jacobi == 0 means gcd(|d|, n) > 1, so n is composite unless n == |d|.
            if (symbol == 0)
            {
                return n == abs(d);
            }

            // 5, -7, 9, -11, ...
            d = (d < 0) ? abs(d) + 2 : -(d + 2);
        }

        return lucasLehmerSequence(d, nPlusOne, n) % n == 0;
    }

    /**
    Computes the Jacobi symbol (p / n); returns -1, 0 or 1.
        Assumes n is odd and positive. Only the low 32-bit digit of n and n % p are used.
    */
    static int jacobiSymbol(int p, const BigInt n)
    {
        if (p == 0)
            return 0;

        int j = 1;
        int u = cast(int) (n.getDigit!uint(0));

        // Make p positive
        if (p < 0)
        {
            p = -p;
            immutable n8 = u & 7;
            if ((n8 == 3) || (n8 == 7))
                j = -j; // 3 (011) or 7 (111) mod 8
        }

        // Get rid of factors of 2 in p
        while ((p & 3) == 0)
            p >>= 2;
        if ((p & 1) == 0)
        {
            p >>= 1;
            if (((u ^ (u >> 1)) & 2) != 0)
                j = -j; // 3 (011) or 5 (101) mod 8
        }
        if (p == 1)
            return j;

        // Then, apply quadratic reciprocity
        if ((p & u & 2) != 0)   // p = u = 3 (mod 4)?
            j = -j;
        // And reduce u mod p
        u = n % p;

        // Now compute Jacobi(u,p), u < p
        while (u != 0)
        {
            while ((u & 3) == 0)
                u >>= 2;
            if ((u & 1) == 0)
            {
                u >>= 1;
                if (((p ^ (p >> 1)) & 2) != 0)
                    j = -j;     // 3 (011) or 5 (101) mod 8
            }
            if (u == 1)
                return j;

            // Now both u and p are odd, so use quadratic reciprocity
            assert(u < p);
            swap(u, p);
            if ((u & p & 2) != 0) // u = p = 3 (mod 4)?
                j = -j;

            // Now u >= p, so it can be reduced
            u %= p;
        }

        return 0;
    }

    /**
    Computes U_k (mod n) of the Lucas sequence with P = 1 and discriminant z,
    by scanning the bits of k from the most significant one down.
    The result may be negative; callers compare it with `% n == 0`.
    */
    static BigInt lucasLehmerSequence(const int z, const BigInt k, const BigInt n)
    {
        BigInt d = z;
        BigInt u = 1, u2;
        BigInt v = 1, v2;

        // (u, v) start as (U_1, V_1), which accounts for the top bit of k; process the bits below it.
        for (long i = cast(long) significantBits(k) - 2; i >= 0; i--)
        {
            // Doubling step: U_2m = U_m * V_m, V_2m = (V_m^2 + d * U_m^2) / 2.
            u2 = (u * v) % n;
            v2 = (v * v + d * u * u) % n;
            // Halve modulo odd n: subtracting n from an odd value makes it even without changing it mod n.
            if (testBit(v2, 0))
                v2 -= n;
            v2 >>= 1;

            u = u2; v = v2;
            if (testBit(k, cast(size_t) i))
            {
                // Increment step: U_(m+1) = (U_m + V_m) / 2, V_(m+1) = (V_m + d * U_m) / 2.
                u2 = (u + v) % n;
                if (testBit(u2, 0))
                    u2 -= n;
                u2 >>= 1;

                v2 = (v + d * u) % n;
                if (testBit(v2, 0))
                    v2 -= n;
                v2 >>= 1;

                u = u2; v = v2;
            }
        }

        return u;
    }

    /// Number of significant bits in the magnitude of `value` (0 for zero).
    static size_t significantBits(const BigInt value) pure nothrow
    {
        immutable size_t digits = value.uintLength;
        uint top = value.getDigit!uint(digits - 1);
        size_t bits = (digits - 1) * uint.sizeof * 8;

        for (; top != 0; top >>= 1)
        {
            ++bits;
        }

        return bits;
    }

    /// Tests bit `bit` of the magnitude of `value`; bits beyond the stored digits read as 0.
    static bool testBit(const BigInt value, size_t bit) pure nothrow
    {
        enum bitsPerDigit = uint.sizeof * 8;
        immutable size_t index = bit / bitsPerDigit;
        if (index >= value.uintLength)
        {
            return false;
        }

        return ((value.getDigit!uint(index) >> (bit % bitsPerDigit)) & 1) != 0;
    }

    /// Returns true if n is a perfect square (integer Newton iteration for floor(sqrt(n))).
    static bool isPerfectSquare(const BigInt n)
    {
        if (n < 0)
        {
            return false;
        }
        if (n < 2)
        {
            return true;
        }

        // 2^^ceil(bits / 2) is >= sqrt(n); from above, Newton's iteration decreases monotonically to floor(sqrt(n)).
        BigInt x = BigInt(1) << cast(int)((significantBits(n) + 1) / 2);
        for (;;)
        {
            BigInt y = (x + n / x) >> 1;
            if (y >= x)
            {
                break;
            }
            x = y;
        }

        return x * x == n;
    }

    immutable static BigInt[] smallPrimesTable = [
        2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61,
        67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137,
        139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211,
        223, 227, 229, 233, 239, 241 ];

    unittest
    {
        // Round trip between big-endian bytes and BigInt.
        const ubyte[] bytes = [0x01, 0x02, 0x03, 0x04, 0x05];
        assert(BigIntHelper.fromBytes(bytes) == BigInt("0x0102030405"));
        assert(BigIntHelper.toUBytes(BigInt("0x0102030405")) == bytes);

        // Small primes and composites, including Carmichael numbers and former range-boundary pseudoprimes.
        assert(!BigIntHelper.isProbablePrime(BigInt(1), 1));
        assert(BigIntHelper.isProbablePrime(BigInt(2), 1));
        assert(BigIntHelper.isProbablePrime(BigInt(3), 1));
        assert(BigIntHelper.isProbablePrime(BigInt(97), 1));
        assert(!BigIntHelper.isProbablePrime(BigInt(1105), 1));
        assert(!BigIntHelper.isProbablePrime(BigInt(1729), 1));
        assert(!BigIntHelper.isProbablePrime(BigInt(9_080_191), 1));
        assert(!BigIntHelper.isProbablePrime(BigInt(4_759_123_141), 1));

        // Lucas test on the Mersenne prime 2^^61 - 1 and on a perfect square.
        assert(BigIntHelper.lucasLehmerTest(BigInt(2_305_843_009_213_693_951)));
        assert(!BigIntHelper.lucasLehmerTest(BigInt(9)));
    }
}