Fixed carry propagation bug in P-256 'm62' implementation (found by Auke Zeilstra...
[BearSSL] / src / ec / ec_p256_m31.c
index 0631a13..d57ef7b 100644 (file)
@@ -394,7 +394,7 @@ mul_f256(uint32_t *d, const uint32_t *a, const uint32_t *b)
        uint32_t t[18];
        uint64_t s[18];
        uint64_t cc, x;
-       uint32_t z;
+       uint32_t z, c;
        int i;
 
        mul9(t, a, b);
@@ -423,17 +423,17 @@ mul_f256(uint32_t *d, const uint32_t *a, const uint32_t *b)
        }
 
        for (i = 17; i >= 9; i --) {
-               uint64_t x;
-
-               x = s[i];
-               s[i - 1] += ARSHW(x, 2);
-               s[i - 2] += (x << 28) & 0x3FFFFFFF;
-               s[i - 2] -= ARSHW(x, 4);
-               s[i - 3] -= (x << 26) & 0x3FFFFFFF;
-               s[i - 5] -= ARSHW(x, 10);
-               s[i - 6] -= (x << 20) & 0x3FFFFFFF;
-               s[i - 8] += ARSHW(x, 16);
-               s[i - 9] += (x << 14) & 0x3FFFFFFF;
+               uint64_t y;
+
+               y = s[i];
+               s[i - 1] += ARSHW(y, 2);
+               s[i - 2] += (y << 28) & 0x3FFFFFFF;
+               s[i - 2] -= ARSHW(y, 4);
+               s[i - 3] -= (y << 26) & 0x3FFFFFFF;
+               s[i - 5] -= ARSHW(y, 10);
+               s[i - 6] -= (y << 20) & 0x3FFFFFFF;
+               s[i - 8] += ARSHW(y, 16);
+               s[i - 9] += (y << 14) & 0x3FFFFFFF;
        }
 
        /*
@@ -465,7 +465,15 @@ mul_f256(uint32_t *d, const uint32_t *a, const uint32_t *b)
        d[8] &= 0xFFFF;
 
        /*
-        * Subtract cc*p.
+        * One extra round of reduction, for cc*2^256, which means
+        * adding cc*(2^224-2^192-2^96+1) to a 256-bit (nonnegative)
+        * value. If cc is negative, then it may happen (rarely, but
+        * not neglectibly so) that the result would be negative. In
+        * order to avoid that, if cc is negative, then we add the
+        * modulus once. Note that if cc is negative, then propagating
+        * that carry must yield a value lower than the modulus, so
+        * adding the modulus once will keep the final result under
+        * twice the modulus.
         */
        z = (uint32_t)cc;
        d[3] -= z << 6;
@@ -473,6 +481,12 @@ mul_f256(uint32_t *d, const uint32_t *a, const uint32_t *b)
        d[7] -= ARSH(z, 18);
        d[7] += (z << 14) & 0x3FFFFFFF;
        d[8] += ARSH(z, 16);
+       c = z >> 31;
+       d[0] -= c;
+       d[3] += c << 6;
+       d[6] += c << 12;
+       d[7] -= c << 14;
+       d[8] += c << 16;
        for (i = 0; i < 9; i ++) {
                uint32_t w;
 
@@ -492,7 +506,7 @@ square_f256(uint32_t *d, const uint32_t *a)
        uint32_t t[18];
        uint64_t s[18];
        uint64_t cc, x;
-       uint32_t z;
+       uint32_t z, c;
        int i;
 
        square9(t, a);
@@ -521,17 +535,17 @@ square_f256(uint32_t *d, const uint32_t *a)
        }
 
        for (i = 17; i >= 9; i --) {
-               uint64_t x;
-
-               x = s[i];
-               s[i - 1] += ARSHW(x, 2);
-               s[i - 2] += (x << 28) & 0x3FFFFFFF;
-               s[i - 2] -= ARSHW(x, 4);
-               s[i - 3] -= (x << 26) & 0x3FFFFFFF;
-               s[i - 5] -= ARSHW(x, 10);
-               s[i - 6] -= (x << 20) & 0x3FFFFFFF;
-               s[i - 8] += ARSHW(x, 16);
-               s[i - 9] += (x << 14) & 0x3FFFFFFF;
+               uint64_t y;
+
+               y = s[i];
+               s[i - 1] += ARSHW(y, 2);
+               s[i - 2] += (y << 28) & 0x3FFFFFFF;
+               s[i - 2] -= ARSHW(y, 4);
+               s[i - 3] -= (y << 26) & 0x3FFFFFFF;
+               s[i - 5] -= ARSHW(y, 10);
+               s[i - 6] -= (y << 20) & 0x3FFFFFFF;
+               s[i - 8] += ARSHW(y, 16);
+               s[i - 9] += (y << 14) & 0x3FFFFFFF;
        }
 
        /*
@@ -563,7 +577,15 @@ square_f256(uint32_t *d, const uint32_t *a)
        d[8] &= 0xFFFF;
 
        /*
-        * Subtract cc*p.
+        * One extra round of reduction, for cc*2^256, which means
+        * adding cc*(2^224-2^192-2^96+1) to a 256-bit (nonnegative)
+        * value. If cc is negative, then it may happen (rarely, but
+        * not neglectibly so) that the result would be negative. In
+        * order to avoid that, if cc is negative, then we add the
+        * modulus once. Note that if cc is negative, then propagating
+        * that carry must yield a value lower than the modulus, so
+        * adding the modulus once will keep the final result under
+        * twice the modulus.
         */
        z = (uint32_t)cc;
        d[3] -= z << 6;
@@ -571,6 +593,12 @@ square_f256(uint32_t *d, const uint32_t *a)
        d[7] -= ARSH(z, 18);
        d[7] += (z << 14) & 0x3FFFFFFF;
        d[8] += ARSH(z, 16);
+       c = z >> 31;
+       d[0] -= c;
+       d[3] += c << 6;
+       d[6] += c << 12;
+       d[7] -= c << 14;
+       d[8] += c << 16;
        for (i = 0; i < 9; i ++) {
                uint32_t w;
 
@@ -1061,7 +1089,7 @@ p256_decode(p256_jacobian *P, const void *src, size_t len)
        memcpy(P->y, ty, sizeof ty);
        memset(P->z, 0, sizeof P->z);
        P->z[0] = 1;
-       return NEQ(bad, 0) ^ 1;
+       return EQ(bad, 0);
 }
 
 /*