// modpwr.c
// To calculate: a^b mod m = a^b % m
// 1. Write b in binary: b = 2^k + ... + 1 when b is odd, or b = 2^k + ... + 2^j (j >= 1) when b is even.
// 2. Only the terms whose bit in b is 1 need to be calculated.
// 3. The powers a^(2^i) could be kept in a lookup table to reduce the calculations - todo
// 4. Expand the series using the rules of modular multiplication, and combine the results at the end:
//    a^b % m = a^(2^k + ...) % m = a^(2^k) * a^(...) % m = ((a^(2^k) % m) * (a^(...) % m)) % m
#include <limits.h>
#include <stdbool.h>
#include <stdio.h>
/*
 * testbit - report whether bit k of x is set.
 * @x: value to inspect. It is converted to unsigned first, so its
 *     two's-complement bit pattern is examined (e.g. bit 31 of -1 is set).
 * @k: bit index, 0 = least significant bit.
 * Returns true if bit k of x is 1. Returns false if it is 0, or if k lies
 * outside [0, number of bits in int).
 *
 * Converting to unsigned avoids the implementation-defined right shift of a
 * negative int. The range check avoids the undefined behaviour of shifting
 * by a negative count or by >= the width of the type.
 */
bool testbit(int x, int k){
    if (k < 0 || k >= (int)(sizeof x * CHAR_BIT))
        return false; // no such bit
    return ((unsigned)x >> k) & 1u;
}
/*
 * modpwr - compute a^b mod m by binary (square-and-multiply) exponentiation.
 * @a: base. It may be negative; it is first reduced into [0, m).
 * @b: exponent, must be >= 0.
 * @m: modulus, must be > 0.
 * Returns a^b mod m in [0, m), so the result is 0 when m == 1 and 1 when
 * b == 0 and m > 1. Returns -1 if m <= 0 or b < 0.
 *
 * The bits of b are scanned from least to most significant while `base`
 * holds a^(2^i) mod m. Only O(log b) multiplications are needed, instead of
 * 2^i repeated multiplications per set bit. Intermediate products are kept
 * in long long: both factors are < m <= INT_MAX, so each product fits.
 */
int modpwr(int a, int b, int m){ // a^b mod m
    if (m <= 0 || b < 0)
        return -1; // invalid modulus or negative exponent: no result in [0, m)

    long long mod = m;
    long long base = a % mod;   // C's % keeps the sign of a ...
    if (base < 0)
        base += mod;            // ... so move negative remainders into [0, m)
    long long result = 1 % mod; // a^0 == 1, but 1 mod 1 == 0

    for (int e = b; e != 0; e >>= 1) {
        if (e & 1)
            result = result * base % mod; // bit i of b is 1: multiply in a^(2^i)
        base = base * base % mod;         // a^(2^i) -> a^(2^(i+1))
    }
    return (int)result;
}
int main(void) {
    /* Demo: 3^3 = 27, and 27 mod 6 = 3. */
    const int demo = modpwr(3, 3, 6);
    printf("%d \n", demo);
    return 0;
}
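/* Blog page footer captured along with the source (not code):
   "沒有留言:" (No comments:) / "張貼留言" (Post a comment) */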