1 module crypto.bigint;
2 
3 import std.bigint;
4 import std.array : Appender;
5 import std.algorithm.mutation : reverse;
6 import std.conv : to, text;
7 import std.exception : enforce;
8 
9 import crypto.random;
10 
struct BigIntHelper
{
    /**
    Randomly generate a non-negative BigInt from `bitLength` random bits.

    Params:
        bitLength = number of random bits; must be a positive multiple of 8.
        highBit   = 0 forces the most significant bit to 0, 1 forces it to 1;
                    any other value (default -1) leaves it random.
        lowBit    = 0 forces the least significant bit to 0, 1 forces it to 1;
                    any other value (default -1) leaves it random.

    Returns: the random bytes interpreted as a big-endian unsigned integer.

    Throws: `Exception` (from `enforce`) if `bitLength` is 0 or not a multiple of 8.
    */
    static BigInt randomGenerate(uint bitLength, int highBit = -1, int lowBit = -1)
    {
        enforce((bitLength > 0) && (bitLength % 8 == 0));

        ubyte[] buffer = new ubyte[bitLength / 8];

        // Every draw from `rnd.next` (kept as a uint) supplies uint.sizeof bytes,
        // lowest byte first; a new value is drawn whenever `pos` wraps to 0.
        uint pos = 0;
        uint current = 0;
        foreach (ref a; buffer)
        {
            if (pos == 0)
            {
                current = rnd.next;
            }

            a = cast(ubyte)(current >> (8 * pos));
            pos = (pos + 1) % uint.sizeof;
        }

        // buffer[0] becomes the most significant byte in bigIntFromUByteArray.
        if (highBit == 0)
        {
            buffer[0] &= 0x7F;
        }
        else if (highBit == 1)
        {
            buffer[0] |= 0x80;
        }

        if (lowBit == 0)
        {
            buffer[$ - 1] &= 0xFE;
        }
        else if (lowBit == 1)
        {
            buffer[$ - 1] |= 0x01;
        }

        return BigIntHelper.bigIntFromUByteArray(buffer);
    }

    /**
    Randomly generate a BigInt in the closed interval [min, max].

    Throws: `Exception` (from `enforce`) if `max < min`.
    */
    static BigInt randomGenerate(const BigInt min, const BigInt max)
    {
        enforce(max >= min, text("BigIntHelper.randomGenerate(): invalid bounding interval ", min, ", ", max));

        // Size the random draw from the width of the interval, not from `max`:
        // with a negative `min` the interval can be much wider than `max`,
        // and a draw sized from `max` could not reach most of it.
        BigInt range = max - min + 1;

        // One extra 32-bit limb beyond what `range` needs keeps the modulo bias
        // of `r % range` below 2^^-32 (assuming `rnd` yields uniform bits).
        BigInt r = randomGenerate(cast(uint)((range.uintLength + 1) * uint.sizeof * 8));
        return r % range + min;
    }

    /**
    Convert a BigInt to big-endian bytes without leading zero bytes.

    Returns: the bytes, most significant first. The loop only runs while
    `value > 0`, so zero and negative values yield an empty array.
    */
    static ubyte[] bigIntToUByteArray(BigInt value)
    {
        Appender!(ubyte[]) app;

        // Peel off the lowest byte each round (little-endian order) ...
        while (value > 0)
        {
            app.put((value - ((value >> 8) << 8)).to!ubyte);
            value >>= 8;
        }

        // ... then flip to big-endian.
        reverse(app.data);

        return app.data;
    }

    /**
    Interpret `buffer` as a big-endian unsigned integer.

    Returns: the value; an empty buffer yields 0.
    */
    static BigInt bigIntFromUByteArray(in ubyte[] buffer)
    {
        BigInt ret = BigInt("0");

        for (uint i; i < buffer.length; i++)
        {
            ret <<= 8;
            ret += buffer[i];
        }

        return ret;
    }

    static if (__VERSION__ >= 2087)
        alias powmod = std.bigint.powmod;
    else
    {
        /**
        Modular exponentiation `base ^^ exponent % modulus` by recursive squaring.

        Fallback for compilers older than 2.087. It multiplies through `mul`
        to avoid the Phobos BigInt multiplication bug noted on `mul`.
        Requires base >= 1, exponent >= 0, modulus >= 1 (asserted).
        */
        static BigInt powmod(const BigInt base, const BigInt exponent, const BigInt modulus)
        {
            assert(base >= 1 && exponent >= 0 && modulus >= 1);

            if (exponent == 0)
            {
                return BigInt(1) % modulus;
            }

            if (exponent == 1)
            {
                return base % modulus;
            }

            BigInt temp = powmod(base, exponent / 2, modulus);

            return (exponent & 1) ? mul(mul(temp, temp), base) % modulus : mul(temp, temp) % modulus;
        }
    }

    /**
    Miller-Rabin (strong probable-prime) primality test.

    Small factors are first removed by trial division. For n < 10^^16 a fixed
    witness set is used, taken from published bounds under which the strong
    test is exact. For larger n, `confidence` random witnesses in [2, n - 2]
    are tried. Assuming uniformly random witnesses, a composite n is then
    accepted with probability at most 4^^-confidence.

    Params:
        n          = the number to test.
        confidence = number of random witnesses used when n >= 10^^16; must be > 0.

    Returns: false if n < 2 or n is proven composite; otherwise true.

    Throws: `Exception` (from `enforce`) if `confidence` is 0.
    */
    static bool millerRabinPrimeTest(const BigInt n, const size_t confidence)
    {
        enforce(confidence > 0, "confidence must be a positive integer greater than 0.");

        if (n < 2)
        {
            return false;
        }

        // Trial division settles small n directly (including 2 and 3). It also
        // guarantees that n is odd and larger than 37 from here on, so every
        // fixed witness below is smaller than n.
        static immutable int[] smallPrimes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37];
        foreach (int p; smallPrimes)
        {
            if (n == p)
            {
                return true;
            }

            if (n % p == 0)
            {
                return false;
            }
        }

        // The upper bounds are strict: each bound value is itself a strong
        // pseudoprime to the witnesses of its branch.
        BigInt[] bases;
        if (n < 1_373_653)
        {
            bases = [BigInt(2), BigInt(3)];
        }
        else if (n < 9_080_191)
        {
            bases = [BigInt(31), BigInt(73)];
        }
        else if (n < 4_759_123_141)
        {
            bases = [BigInt(2), BigInt(7), BigInt(61)];
        }
        else if (n < 2_152_302_898_747)
        {
            bases = [BigInt(2), BigInt(3), BigInt(5), BigInt(7), BigInt(11)];
        }
        else if (n < 341_550_071_728_321)
        {
            // 46_856_248_255_981 == 4_840_261 * 9_680_521. It is the known exception
            // for the witness set of the next branch; kept as an explicit shortcut.
            if (n == 46_856_248_255_981)
            {
                return false;
            }

            bases = [BigInt(2), BigInt(3), BigInt(5), BigInt(7), BigInt(11), BigInt(13), BigInt(17)];
        }
        else if (n < 10_000_000_000_000_000)
        {
            bases = [BigInt(2), BigInt(3), BigInt(7), BigInt(61), BigInt(24251)];
        }
        else
        {
            // Random witnesses in [2, n - 2]; n - 1 always passes, so it proves nothing.
            bases = new BigInt[confidence];
            foreach (ref b; bases)
            {
                b = randomGenerate(BigInt(2), n - 2);
            }
        }

        // Write n - 1 = d * 2^^s with d odd (n is odd here, so s >= 1).
        BigInt d = n - 1;
        uint s = 0;
        while (d % 2 == 0)
        {
            d >>= 1;
            s++;
        }

        foreach (a; bases)
        {
            if (!isStrongProbablePrime(n, d, s, a))
            {
                return false;
            }
        }

        return true;
    }

private:

    /*
    One round of the strong probable-prime test for odd n > 2, where
    n - 1 == d * 2^^s and d is odd. Returns false only when witness `a`
    proves n composite.
    */
    static bool isStrongProbablePrime(const BigInt n, const BigInt d, const uint s, BigInt a)
    {
        a = a % n;
        if (a == 0)
        {
            // A witness divisible by n carries no information.
            return true;
        }

        const BigInt nMinusOne = n - 1;
        const BigInt two = BigInt(2);

        BigInt x = powmod(a, d, n);
        if (x == 1 || x == nMinusOne)
        {
            return true;
        }

        foreach (_; 1 .. s)
        {
            if (x == 0)
            {
                // n divides a power of a but not a itself, so n is composite.
                // Returning here also avoids calling the pre-2.087 powmod
                // fallback with base 0 (it asserts base >= 1).
                return false;
            }

            // Squaring goes through powmod so that older compilers keep using
            // the `mul` workaround.
            x = powmod(x, two, n);

            if (x == nMinusOne)
            {
                return true;
            }

            if (x == 1)
            {
                // A non-trivial square root of 1 exists modulo n: composite.
                return false;
            }
        }

        return false;
    }

    /++
        Bug BigInt mul() of phobos will be fixed in version 2.087.0
        Details:
            https://github.com/dlang/phobos/pull/6972

        Schoolbook multiplication on 32-bit limbs, with carries propagated per
        partial product. Signs are ignored and the result is non-negative.
        Used by the pre-2.087 `powmod` fallback.
    +/
    static BigInt mul(const BigInt a, const BigInt b)
    {
        uint[] au = bigIntToUintArr(a);
        uint[] bu = bigIntToUintArr(b);

        // The product of m- and n-limb numbers fits in m + n limbs.
        uint[] r = new uint[au.length + bu.length];

        for (size_t i = 0; i < bu.length; i++)
        {
            for (size_t j = 0; j < au.length; j++)
            {
                ulong t = cast(ulong)bu[i] * au[j] + r[i + j];
                r[i + j] = t & 0xFFFF_FFFF;
                uint c = t >> 32;
                size_t h = i + j + 1;

                while (c != 0)
                {
                    t = cast(ulong)c + r[h];
                    r[h] = t & 0xFFFF_FFFF;
                    c = t >> 32;
                    h++;
                }
            }
        }

        return uintArrToBigInt(r);
    }

    /*
    Copy the 32-bit digits of `data` (via getDigit) into an array, with index 0
    holding digit 0. The sign is not represented in the result.
    */
    static uint[] bigIntToUintArr(const BigInt data)
    {
        size_t n = data.uintLength();
        uint[] arr = new uint[n];

        for (size_t i = 0; i < n; i++)
        {
            arr[i] = data.getDigit!uint(i);
        }

        return arr;
    }

    /*
    Build a non-negative BigInt from 32-bit limbs, with arr[0] least significant.
    High-order zero limbs are skipped before accumulating.
    */
    static BigInt uintArrToBigInt(const uint[] arr)
    {
        size_t zeros = 0;
        foreach_reverse(d; arr)
        {
            if (d != 0)
            {
                break;
            }

            zeros++;
        }

        BigInt data = 0;

        foreach_reverse (d; arr[0..$ - zeros])
        {
            data <<= 32;
            data += d;
        }

        return data;
    }
}