19
19
* This algorithm uses Newton's approximation
20
20
* x[i+1] = x[i] - f(x[i])/f'(x[i])
21
21
* which will find the root in log(N) time where
22
- * each step involves a fair bit. This is not meant to
23
- * find huge roots [square and cube, etc].
22
+ * each step involves a fair bit.
24
23
*/
25
24
int mp_n_root_ex (const mp_int * a , mp_digit b , mp_int * c , int fast )
26
25
{
27
26
mp_int t1 , t2 , t3 , a_ ;
28
- int res ;
27
+ int res , cmp ;
28
+ int ilog2 ;
29
29
30
30
/* input must be positive if b is even */
31
31
if (((b & 1u ) == 0u ) && (a -> sign == MP_NEG )) {
@@ -48,9 +48,49 @@ int mp_n_root_ex(const mp_int *a, mp_digit b, mp_int *c, int fast)
48
48
a_ = * a ;
49
49
a_ .sign = MP_ZPOS ;
50
50
51
- /* t2 = 2 */
52
- mp_set (& t2 , 2uL );
53
-
51
+ /* Compute seed: 2^(log_2(n)/b + 2)*/
52
+ ilog2 = mp_count_bits (a );
53
+
54
+ /*
55
+ GCC and clang do not understand the sizeof tests and complain,
56
+ icc (the Intel compiler) seems to understand, at least it doesn't complain.
57
+ 2 of 3 say these macros are necessary, so there they are.
58
+ */
59
+ #if ( !(defined MP_8BIT ) && !(defined MP_16BIT ) )
60
+ /*
61
+ The type of mp_digit might be larger than an int.
62
+ If "b" is larger than INT_MAX it is also larger than
63
+ log_2(n) because the bit-length of the "n" is measured
64
+ with an int and hence the root is always < 2 (two).
65
+ */
66
+ if (sizeof (mp_digit ) >= sizeof (int )) {
67
+ if (b > (mp_digit )(INT_MAX /2 )) {
68
+ mp_set (c , 1uL );
69
+ c -> sign = a -> sign ;
70
+ res = MP_OKAY ;
71
+ goto LBL_T3 ;
72
+ }
73
+ }
74
+ #endif
75
+ /* "b" is smaller than INT_MAX, we can cast safely */
76
+ if (ilog2 < (int )b ) {
77
+ mp_set (c , 1uL );
78
+ c -> sign = a -> sign ;
79
+ res = MP_OKAY ;
80
+ goto LBL_T3 ;
81
+ }
82
+ ilog2 = ilog2 / ((int )b );
83
+ if (ilog2 == 0 ) {
84
+ mp_set (c , 1uL );
85
+ c -> sign = a -> sign ;
86
+ res = MP_OKAY ;
87
+ goto LBL_T3 ;
88
+ }
89
+ /* Start value must be larger than root */
90
+ ilog2 += 2 ;
91
+ if ((res = mp_2expt (& t2 ,ilog2 )) != MP_OKAY ) {
92
+ goto LBL_T3 ;
93
+ }
54
94
do {
55
95
/* t1 = t2 */
56
96
if ((res = mp_copy (& t2 , & t1 )) != MP_OKAY ) {
@@ -63,7 +103,6 @@ int mp_n_root_ex(const mp_int *a, mp_digit b, mp_int *c, int fast)
63
103
if ((res = mp_expt_d_ex (& t1 , b - 1u , & t3 , fast )) != MP_OKAY ) {
64
104
goto LBL_T3 ;
65
105
}
66
-
67
106
/* numerator */
68
107
/* t2 = t1**b */
69
108
if ((res = mp_mul (& t3 , & t1 , & t2 )) != MP_OKAY ) {
@@ -89,14 +128,39 @@ int mp_n_root_ex(const mp_int *a, mp_digit b, mp_int *c, int fast)
89
128
if ((res = mp_sub (& t1 , & t3 , & t2 )) != MP_OKAY ) {
90
129
goto LBL_T3 ;
91
130
}
131
+ /*
132
+ Number of rounds is at most log_2(root). If it is more it
133
+ got stuck, so break out of the loop and do the rest manually.
134
+ */
135
+ if (ilog2 -- == 0 ) {
136
+ break ;
137
+ }
92
138
} while (mp_cmp (& t1 , & t2 ) != MP_EQ );
93
139
94
140
/* result can be off by a few so check */
141
+ /* Loop beneath can overshoot by one if found root is smaller than actual root */
142
+ for (;;) {
143
+ if ((res = mp_expt_d_ex (& t1 , b , & t2 , fast )) != MP_OKAY ) {
144
+ goto LBL_T3 ;
145
+ }
146
+ cmp = mp_cmp (& t2 , & a_ );
147
+ if (cmp == MP_EQ ) {
148
+ res = MP_OKAY ;
149
+ goto LBL_T3 ;
150
+ }
151
+ if (cmp == MP_LT ) {
152
+ if ((res = mp_add_d (& t1 , 1uL , & t1 )) != MP_OKAY ) {
153
+ goto LBL_T3 ;
154
+ }
155
+ } else {
156
+ break ;
157
+ }
158
+ }
159
+ /* correct overshoot from above or from recurrence */
95
160
for (;;) {
96
161
if ((res = mp_expt_d_ex (& t1 , b , & t2 , fast )) != MP_OKAY ) {
97
162
goto LBL_T3 ;
98
163
}
99
-
100
164
if (mp_cmp (& t2 , & a_ ) == MP_GT ) {
101
165
if ((res = mp_sub_d (& t1 , 1uL , & t1 )) != MP_OKAY ) {
102
166
goto LBL_T3 ;
@@ -123,7 +187,6 @@ int mp_n_root_ex(const mp_int *a, mp_digit b, mp_int *c, int fast)
123
187
return res ;
124
188
}
125
189
#endif
126
-
127
190
/* ref: $Format:%D$ */
128
191
/* git commit: $Format:%H$ */
129
192
/* commit time: $Format:%ai$ */
0 commit comments