1 module crypto.bigint;
2 
3 import std.bigint;
4 import std.algorithm.mutation : reverse, swap;
5 import std.algorithm.searching : find, all;
6 import std.conv : to, text;
7 import std.exception : enforce;
8 import std.range : repeat, array;
9 import std.math : abs;
10 
11 import crypto.random;
12 
13 struct BigIntHelper
14 {
15     /// Random generate a BigInt by bitLength.
16     static BigInt randomGenerate(uint bitLength, int highBit = -1, int lowBit = -1)
17     {
18         enforce((bitLength > 0) && (bitLength % 8 == 0));
19 
20         ubyte[] buffer = new ubyte[bitLength / 8];
21 
22         uint pos = 0;
23         uint current = 0;
24         foreach (ref a; buffer)
25         {
26             if (pos == 0)
27             {
28                 current = rnd.next;
29             }
30 
31             a = cast(ubyte)(current >> 8 * pos);
32             pos = (pos + 1) % uint.sizeof;
33         }
34 
35         if (highBit == 0)
36         {
37             buffer[0] &= (0xFF >> 1);
38         }
39         else if (highBit == 1)
40         {
41             buffer[0] |= (0x01 << 7);
42         }
43 
44         if (lowBit == 0)
45         {
46             buffer[$ - 1] &= (0xFF << 1);
47         }
48         else if (lowBit == 1)
49         {
50             buffer[$ - 1] |= 0x01;
51         }
52 
53         return BigIntHelper.fromBytes(buffer);
54     }
55 
56     /// Random generate a BigInt between min and max.
57     static BigInt randomGenerate(const BigInt min, const BigInt max)
58     {
59         enforce(max >= min, text("BigIntHelper.randomGenerate(): invalid bounding interval ", min, ", ", max));
60 
61         BigInt r = randomGenerate(cast(uint)((max.uintLength + 1) * uint.sizeof * 8));
62         return r % (max - min + 1) + min;
63     }
64 
65     ///
66     static ubyte[] toUBytes(const BigInt value) pure nothrow
67     {
68         size_t len = value.uintLength();
69         ubyte[] ubytes = new ubyte[len * uint.sizeof];
70 
71         for (size_t i = 0; i < len; i++)
72         {
73             uint digit = value.getDigit!uint(i);
74             ubyte* p = cast(ubyte*)&digit;
75 
76             for (size_t j = 0; j < uint.sizeof; j++)
77             {
78                 ubytes[(len - i - 1) * uint.sizeof + (uint.sizeof - j - 1)] = *(p + j);
79             }
80         }
81 
82         return ubytes.find!((a, b) => a != b)(0);
83     }
84 
85     /++
86         Create a BigInt from an array representing a magnitude,
87         most significant element first.
88 
89         !! Prior to Phobos version 2.094 here is a performance bottleneck.
90     +/
91     static BigInt fromBytes(in ubyte[] buffer) pure nothrow
92     {
93         static if (__VERSION__ >= 2094)
94             return BigInt(false, buffer);
95         else
96         {
97             size_t supplement = (uint.sizeof - buffer.length % uint.sizeof) % uint.sizeof;
98             ubyte[] bytes = (supplement > 0) ? (cast(ubyte)0).repeat(supplement).array ~ buffer : cast(ubyte[])buffer;
99             BigInt data = 0;
100 
101             for (size_t i = 0; i < bytes.length / uint.sizeof; i++)
102             {
103                 uint digit;
104                 ubyte* p = cast(ubyte*)&digit;
105 
106                 for (size_t j = 0; j < uint.sizeof; j++)
107                 {
108                     *(p + j) = bytes[i * uint.sizeof + uint.sizeof - j - 1];
109                 }
110 
111                 data <<= 32;
112                 data += digit;
113             }
114 
115             return data;
116         }
117     }
118 
119     static if (__VERSION__ >= 2087)
120         alias powmod = std.bigint.powmod;
121     else
122     {
123         ///
124         static BigInt powmod(const BigInt base, const BigInt exponent, const BigInt modulus) pure nothrow
125         {
126             assert(base >= 1 && exponent >= 0 && modulus >= 1);
127 
128             if (exponent == 0)
129             {
130                 return BigInt(1) % modulus;
131             }
132 
133             if (exponent == 1)
134             {
135                 return base % modulus;
136             }
137 
138             BigInt temp = powmod(base, exponent / 2, modulus);
139 
140             return (exponent & 1) ? mul(mul(temp, temp), base) % modulus : mul(temp, temp) % modulus;
141         }
142     }
143 
144     /**
145     Test whether BigInt n is prime.
146         Step 1: millerRabinPrimeTest
147         Step 2: lucasLehmerTest
148     */
149     static bool isProbablePrime(const BigInt n, const size_t confidence)
150     {
151         bool passed = millerRabinPrimeTest(n, confidence);
152 
153         /**
154         When n < 10_000_000_000_000_000,
155         there is no need to lucasLehmerTest, And trust the result of millerRabinPrimeTest.
156         */
157         if (!passed || (n < 10_000_000_000_000_000))
158         {
159             return passed;
160         }
161 
162         return lucasLehmerTest(n);
163     }
164 
165 private:
166 
    /++
    Workaround for a bug in Phobos' BigInt multiplication that was fixed in version 2.087.0.
        Details: https://github.com/dlang/phobos/pull/6972
    +/
171     static if (__VERSION__ < 2087)
172     {
173         static BigInt mul(const BigInt a, const BigInt b) pure nothrow
174         {
175             uint[] au = toUintArray(a);
176             uint[] bu = toUintArray(b);
177 
178             uint[] r = new uint[au.length + bu.length];
179 
180             for (size_t i = 0; i < bu.length; i++)
181             {
182                 for (size_t j = 0; j < au.length; j++)
183                 {
184                     ulong t = cast(ulong)bu[i] * au[j] + r[i + j];
185                     r[i + j] = t & 0xFFFF_FFFF;
186                     uint c = t >> 32;
187                     size_t h = i + j + 1;
188 
189                     while (c != 0)
190                     {
191                         t = cast(ulong)c + r[h];
192                         r[h] = t & 0xFFFF_FFFF;
193                         c = t >> 32;
194                         h++;
195                     }
196                 }
197             }
198 
199             return fromUintArray(r);
200         }
201 
202         static uint[] toUintArray(const BigInt data) pure nothrow
203         {
204             size_t n = data.uintLength();
205             uint[] arr = new uint[n];
206 
207             for (size_t i = 0; i < n; i++)
208             {
209                 arr[i] = data.getDigit!uint(i);
210             }
211 
212             return arr;
213         }
214 
215         static BigInt fromUintArray(const uint[] arr) pure nothrow
216         {
217             size_t zeros = 0;
218             foreach_reverse (d; arr)
219             {
220                 if (d != 0)
221                 {
222                     break;
223                 }
224 
225                 zeros++;
226             }
227 
228             BigInt data = 0;
229 
230             foreach_reverse (d; arr[0..$ - zeros])
231             {
232                 data <<= 32;
233                 data += d;
234             }
235 
236             return data;
237         }
238     }
239 
240     ///
241     static bool millerRabinPrimeTest(const BigInt n, const size_t confidence)
242     {
243         enforce(confidence > 0, "confidence must be a positive integer greater than 0.");
244 
245         if (n < 2)
246         {
247             return false;
248         }
249         if (n == 2)
250         {
251             return true;
252         }
253 
254         BigInt[] bases;
255         if (n < 1_373_653)
256         {
257             bases = [BigInt(2), BigInt(3)];
258         }
259         else if (n <= 9_080_191)
260         {
261             bases = [BigInt(31), BigInt(73)];
262         }
263         else if (n <= 4_759_123_141)
264         {
265             bases = [BigInt(2), BigInt(7), BigInt(61)];
266         }
267         else if (n <= 2_152_302_898_747)
268         {
269             bases = [BigInt(2), BigInt(3), BigInt(5), BigInt(7), BigInt(11)];
270         }
271         else if (n <= 341_550_071_728_320)
272         {
273             if (n == 46_856_248_255_981)
274             {
275                 return false;
276             }
277 
278             bases = [BigInt(2), BigInt(3), BigInt(5), BigInt(7), BigInt(11), BigInt(13), BigInt(17)];
279         }
280         else if (n < 10_000_000_000_000_000)
281         {
282             bases = [BigInt(2), BigInt(3), BigInt(7), BigInt(61), BigInt(24251)];
283         }
284         else
285         {
286             if (!smallPrimesTable.all!((prime) => (powmod(prime, n - 1, n) == 1)))
287             {
288                 return false;
289             }
290 
291             /**
292             Although in theory base should be between 2 and n - 1, because confidence is optimized before call,
293             the larger n is, the smaller confidence is, so the requirement for base can not be too small,
294             so the minimum value does not use 2, but uses n / 2 instead.
295             */
296             bases = new BigInt[confidence];
297             import std.algorithm.iteration : each;
298             bases.each!((ref b) => (b = randomGenerate(n / 2, n - 1)));
299             //bases.each!((ref b) => (b = randomGenerate(BigInt(2), n - 1)));
300         }
301 
302         return (bases.all!((base) => (powmod(base, n - 1, n) == 1)));
303     }
304 
305     /**
306     Returns true if n is a Lucas-Lehmer probable prime.
307         The following assumptions are made:
308         BigInt n is a positive, odd number. So it can only be call after millerRabinPrimeTest is passed.
309     */
310     static bool lucasLehmerTest(const BigInt n)
311     {
312         immutable BigInt nPlusOne = n + 1;
313 
314         int d = 5;
315         while (jacobiSymbol(d, n) != -1)
316         {
317             // 5, -7, 9, -11, ...
318             d = (d < 0) ? abs(d) + 2 : -(d + 2);
319         }
320 
321         return lucasLehmerSequence(d, nPlusOne, n) % n == 0;
322     }
323 
324     static int jacobiSymbol(int p, const BigInt n)
325     {
326         if (p == 0)
327             return 0;
328 
329         int j = 1;
330         int u = cast(int) (n.getDigit!uint(0));
331 
332         // Make p positive
333         if (p < 0)
334         {
335             p = -p;
336             immutable n8 = u & 7;
337             if ((n8 == 3) || (n8 == 7))
338                 j = -j; // 3 (011) or 7 (111) mod 8
339         }
340 
341         // Get rid of factors of 2 in p
342         while ((p & 3) == 0)
343             p >>= 2;
344         if ((p & 1) == 0)
345         {
346             p >>= 1;
347             if (((u ^ (u >> 1)) & 2) != 0)
348                 j = -j; // 3 (011) or 5 (101) mod 8
349         }
350         if (p == 1)
351             return j;
352 
353         // Then, apply quadratic reciprocity
354         if ((p & u & 2) != 0)   // p = u = 3 (mod 4)?
355             j = -j;
356         // And reduce u mod p
357         u = n % p;
358 
359         // Now compute Jacobi(u,p), u < p
360         while (u != 0)
361         {
362             while ((u & 3) == 0)
363                 u >>= 2;
364             if ((u & 1) == 0)
365             {
366                 u >>= 1;
367                 if (((p ^ (p >> 1)) & 2) != 0)
368                     j = -j;     // 3 (011) or 5 (101) mod 8
369             }
370             if (u == 1)
371                 return j;
372 
373             // Now both u and p are odd, so use quadratic reciprocity
374             assert(u < p);
375             swap(u, p);
376             if ((u & p & 2) != 0) // u = p = 3 (mod 4)?
377                 j = -j;
378 
379             // Now u >= p, so it can be reduced
380             u %= p;
381         }
382 
383         return 0;
384     }
385 
386     static BigInt lucasLehmerSequence(const int z, const BigInt k, const BigInt n)
387     {
388         bool testBit(const BigInt n, const int m)
389         {
390             int digit = cast(int) (n.getDigit!uint(m >>> 5));
391             return (digit & (1 << (m & 31))) != 0;
392         }
393 
394         BigInt d = z;
395         BigInt u = 1, u2;
396         BigInt v = 1, v2;
397 
398         for (int i = cast(int)(k.uintLength * uint.sizeof * 8 - 2); i >= 0; i--)
399         {
400             u2 = (u * v) % n;
401             v2 = (v * v + d * u * u) % n;
402             if (testBit(v2, 0))
403                 v2 -= n;
404             v2 >>= 1;
405 
406             u = u2; v = v2;
407             if (testBit(k, i))
408             {
409                 u2 = (u + v) % n;
410                 if (testBit(u2, 0))
411                     u2 -= n;
412 
413                 u2 >>= 1;
414                 v2 = (v + d * u) % n;
415                 if (testBit(v2, 0))
416                     v2 -= n;
417                 v2 >>= 1;
418 
419                 u = u2; v = v2;
420             }
421         }
422 
423         return u;
424     }
425 
    /++
        Every prime from 2 through 241, in ascending order.

        millerRabinPrimeTest uses these as fixed bases to screen candidates that
        are at least 10_000_000_000_000_000 before it tries random bases.
    +/
    immutable static BigInt[] smallPrimesTable = [
        2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61,
        67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137,
        139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211,
        223, 227, 229, 233, 239, 241 ];
431 }