Modular Exponentiation for high numbers in C++

前端 未结 6 1739
失恋的感觉
失恋的感觉 2020-12-01 12:48

So I've been working recently on an implementation of the Miller-Rabin primality test. I am limiting it to a scope of all 32-bit numbers, because this is a just-for-fun pr

6条回答
  •  我在风中等你
    2020-12-01 13:26

    I recently wrote something like this for RSA in C++; it's a bit messy, though.

    #include "BigInteger.h"
    #include <iostream>
    #include <sstream>
    #include <stack>
    #include <stdexcept>
    
    // Default constructor: clears the sign flag and appends a single 0 digit,
    // so the digit list is never left empty by construction.
    BigInteger::BigInteger() {
        negative = false;
        digits.push_back(0);
    }
    
    // Destructor: intentionally empty, no explicit cleanup is performed here.
    BigInteger::~BigInteger() {}
    
    // Stores the sum of the digit sequences of a and b into c.digits.
    // Sign flags are neither read nor written. Digits are 16-bit limbs with the
    // least-significant limb at index 0; a missing limb in the shorter operand
    // counts as 0. Trailing (most-significant) zero limbs are trimmed from c,
    // always leaving at least one limb.
    void BigInteger::addWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) {
        const int lenA = static_cast<int>(a.digits.size());
        const int lenB = static_cast<int>(b.digits.size());
        const int width = (lenA < lenB) ? lenB : lenA;
        c.digits.resize(width);

        int carry = 0;
        for (int pos = 0; pos < width; ++pos) {
            const unsigned short fromA = (pos < lenA) ? a.digits[pos] : 0;
            const unsigned short fromB = (pos < lenB) ? b.digits[pos] : 0;
            carry += fromA + fromB;
            c.digits[pos] = (carry & 0xFFFF);  // keep the low 16 bits as this limb
            carry >>= 16;                      // the rest carries into the next limb
        }

        // A carry left over after the top limb becomes a new most-significant limb.
        if (carry != 0) {
            putCarryInfront(c, carry);
        }

        while (c.digits.size() > 1 && c.digits.back() == 0) {
            c.digits.pop_back();
        }
    }
    
    // Stores the limb-wise difference a - b into c.digits, ignoring sign flags.
    // A missing limb in the shorter operand counts as 0. Any borrow still
    // pending after the top limb is discarded, so the result is only meaningful
    // when the magnitude of a is at least that of b. Trailing zero limbs are
    // trimmed, leaving at least one limb.
    void BigInteger::subWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) {
        const int lenA = static_cast<int>(a.digits.size());
        const int lenB = static_cast<int>(b.digits.size());
        const int width = (lenA < lenB) ? lenB : lenA;
        c.digits.resize(width);

        int borrow = 0;  // 0 or -1, carried into the next limb
        for (int pos = 0; pos < width; ++pos) {
            const unsigned short fromA = (pos < lenA) ? a.digits[pos] : 0;
            const unsigned short fromB = (pos < lenB) ? b.digits[pos] : 0;
            const int diff = borrow + (fromA - fromB);
            if (diff < 0) {
                // Borrow one unit of the base (0x10000) from the next limb.
                c.digits[pos] = 0x10000 + diff;
                borrow = -1;
            } else {
                c.digits[pos] = diff;
                borrow = 0;
            }
        }

        while (c.digits.size() > 1 && c.digits.back() == 0) {
            c.digits.pop_back();
        }
    }
    
    // Compares the digit sequences of a and b, ignoring sign flags.
    // Returns -1 if a is smaller, +1 if a is larger, 0 if they are equal.
    // Limbs missing from the shorter operand count as 0, so operands padded
    // with most-significant zero limbs still compare correctly.
    int BigInteger::cmpWithoutSign(const BigInteger& a, const BigInteger& b) {
        const int lenA = static_cast<int>(a.digits.size());
        const int lenB = static_cast<int>(b.digits.size());
        int pos = ((lenA < lenB) ? lenB : lenA) - 1;

        // Walk from the most-significant limb down; the first difference decides.
        while (pos >= 0) {
            const unsigned short fromA = (pos < lenA) ? a.digits[pos] : 0;
            const unsigned short fromB = (pos < lenB) ? b.digits[pos] : 0;
            if (fromA != fromB) {
                return (fromA < fromB) ? -1 : +1;
            }
            --pos;
        }
        return 0;
    }
    
    // Stores the product of a's digit sequence and the single 16-bit limb b
    // into c.digits. Sign flags are not touched. The result is not trimmed:
    // multiplying by 0 leaves as many zero limbs as a has limbs.
    //
    // The multiplication is done in unsigned int on purpose: two unsigned short
    // operands would otherwise be promoted to (signed) int, and 0xFFFF * 0xFFFF
    // does not fit in a 32-bit int, which is undefined behaviour.
    void BigInteger::multByDigitWithoutSign(BigInteger& c, const BigInteger& a, unsigned short b) {
        const unsigned int factor = b;
        unsigned int mult_n_carry = 0;
        c.digits.clear();
        c.digits.resize(a.digits.size());
        for (int i = 0; i < (int)a.digits.size(); ++i) {
            const unsigned short a_digit = a.digits[i];
            // Max value: 0xFFFF * 0xFFFF + 0xFFFF == 0xFFFF0000, fits in 32 bits.
            mult_n_carry += static_cast<unsigned int>(a_digit) * factor;
            c.digits[i] = (mult_n_carry & 0xFFFF);
            mult_n_carry >>= 16;
        }
        // After the final shift the carry is at most 0xFFFF, so it fits one limb.
        if (mult_n_carry != 0) {
            putCarryInfront(c, static_cast<unsigned short>(mult_n_carry));
        }
    }
    
    // Writes into b.digits the limbs of a preceded by `times` zero limbs at the
    // low end, i.e. a moved up by `times` whole 16-bit limbs. Sign flags are
    // not touched. b must be a different object from a.
    void BigInteger::shiftLeftByBase(BigInteger& b, const BigInteger& a, int times) {
        const int srcLen = static_cast<int>(a.digits.size());
        const int dstLen = srcLen + times;
        b.digits.resize(dstLen);
        for (int dst = 0; dst < dstLen; ++dst) {
            if (dst < times) {
                b.digits[dst] = 0;                      // newly opened low limbs
            } else {
                b.digits[dst] = a.digits[dst - times];  // original limbs, moved up
            }
        }
    }
    
    // Shifts a's digit sequence right by one bit in place (halving the
    // magnitude, dropping the lowest bit). The number of limbs is unchanged,
    // so the top limb may become zero.
    void BigInteger::shiftRight(BigInteger& a) {
        bool carryIn = false;  // bit falling out of the limb above
        for (int pos = static_cast<int>(a.digits.size()) - 1; pos >= 0; --pos) {
            const bool carryOut = (a.digits[pos] & 0x1) != 0;
            a.digits[pos] >>= 1;
            if (carryIn) {
                a.digits[pos] |= 0x8000;  // becomes the top bit of this limb
            }
            carryIn = carryOut;
        }
    }
    
    // Shifts a's digit sequence left by one bit in place (doubling the
    // magnitude). If a bit falls out of the top limb, a new limb holding 1 is
    // appended.
    void BigInteger::shiftLeft(BigInteger& a) {
        bool carry = false;  // top bit of the previous (lower) limb
        const int count = static_cast<int>(a.digits.size());
        for (int pos = 0; pos < count; ++pos) {
            const bool overflow = (a.digits[pos] & 0x8000) != 0;
            a.digits[pos] <<= 1;
            if (carry) {
                a.digits[pos] |= 1;
            }
            carry = overflow;
        }
        if (carry) {
            a.digits.push_back(1);
        }
    }
    
    // Appends `carry` as a new most-significant limb of a (limbs are stored
    // least-significant first, so "in front" means at the back of the list).
    // The sign flag of a is left unchanged.
    void BigInteger::putCarryInfront(BigInteger& a, unsigned short carry) {
        a.digits.push_back(carry);
    }
    
    // Binary long division on magnitudes: computes c = |a| / |b| (quotient)
    // and d = |a| % |b| (remainder). The sign flags of a and b are ignored and
    // both results are returned non-negative; callers apply any sign they need.
    //
    // Throws std::domain_error when b is zero. Previously a zero divisor, or
    // any operand with its sign flag set, made the loops below run forever,
    // because the signed -= / += operators were applied to copies that still
    // carried the operands' signs.
    void BigInteger::divideWithoutSign(BigInteger& c, BigInteger& d, const BigInteger& a, const BigInteger& b) {
        BigInteger zero;
        if (cmpWithoutSign(b, zero) == 0) {
            throw std::domain_error("BigInteger: division by zero");
        }

        // Work on magnitude-only copies so the signed operators used below
        // behave like plain unsigned arithmetic.
        BigInteger divisor = b;
        divisor.negative = false;
        BigInteger remainder = a;
        remainder.negative = false;

        BigInteger quotient;           // accumulates the result, starts at 0
        BigInteger shifted = divisor;  // divisor * 2^k
        BigInteger weight("1");        // 2^k, kept in step with `shifted`

        // Scale the divisor up until it exceeds the dividend, then step back once
        // so `shifted` is the largest divisor * 2^k not exceeding the dividend.
        while (cmpWithoutSign(remainder, shifted) >= 0) {
            shiftLeft(shifted);
            shiftLeft(weight);
        }
        shiftRight(shifted);
        shiftRight(weight);

        while (cmpWithoutSign(remainder, divisor) >= 0) {
            remainder -= shifted;
            quotient += weight;
            // Scale back down until `shifted` fits into what is left.
            while (cmpWithoutSign(remainder, shifted) < 0) {
                shiftRight(shifted);
                shiftRight(weight);
            }
        }

        // Assign the outputs last so they are written only after a and b have
        // been fully read.
        c.digits = quotient.digits;
        c.negative = false;
        d.digits = remainder.digits;
        d.negative = false;
    }
    
    // Copy constructor: duplicates the sign flag and the digit list of `other`.
    BigInteger::BigInteger(const BigInteger& other) {
        negative = other.negative;
        digits = other.digits;
    }
    
    // Builds the value from a NUL-terminated decimal string with an optional
    // leading '-'. Each character after the sign is folded in as
    // value = value * 10 + (character - '0'). Characters are not validated.
    // NOTE(review): presumably callers only pass well-formed, non-null decimal
    // strings; confirm before using this on untrusted input.
    BigInteger::BigInteger(const char* other) {
        negative = false;
        digits.push_back(0);

        BigInteger base10;
        base10.digits[0] = 10;

        const char* cursor = other;
        const bool minus = (*cursor == '-');
        if (minus) {
            ++cursor;
        }

        for (; *cursor != 0; ++cursor) {
            BigInteger value;
            value.digits[0] = *cursor - '0';
            *this *= base10;
            *this += value;
        }

        // The sign is applied only after accumulation, so the arithmetic above
        // runs on a non-negative value.
        negative = minus;
    }
    
    // Returns true when the lowest bit of the least-significant limb is set.
    bool BigInteger::isOdd() const {
        const bool lowBitSet = (digits[0] & 0x1) != 0;
        return lowBitSet;
    }
    
    // Copy assignment: takes over the digit list and sign flag of `other`.
    // Assigning an object to itself is a no-op. Returns *this.
    BigInteger& BigInteger::operator=(const BigInteger& other) {
        if (this != &other) {
            negative = other.negative;
            digits = other.digits;
        }
        return *this;
    }
    
    // Signed in-place addition. Returns *this.
    //  - Same sign: magnitudes are added and the shared sign is kept.
    //  - Different signs: the smaller magnitude is subtracted from the larger
    //    one and the result takes the sign of the larger-magnitude operand;
    //    equal magnitudes give a non-negative zero.
    BigInteger& BigInteger::operator+=(const BigInteger& other) {
        BigInteger sum;
        if (negative == other.negative) {
            sum.negative = negative;
            addWithoutSign(sum, *this, other);
        } else {
            const int order = cmpWithoutSign(*this, other);
            if (order == 0) {
                sum.negative = false;
                sum.digits.clear();
                sum.digits.push_back(0);
            } else if (order > 0) {
                // |*this| dominates: keep our sign.
                sum.negative = negative;
                subWithoutSign(sum, *this, other);
            } else {
                // |other| dominates: take its sign.
                sum.negative = other.negative;
                subWithoutSign(sum, other, *this);
            }
        }
        negative = sum.negative;
        digits.swap(sum.digits);
        return *this;
    }
    
    // Signed in-place subtraction, implemented as adding a copy of `other`
    // with its sign flag flipped. Returns *this.
    BigInteger& BigInteger::operator-=(const BigInteger& other) {
        BigInteger flipped(other);
        flipped.negative = !other.negative;
        *this += flipped;
        return *this;
    }
    
    // Signed in-place schoolbook multiplication. Returns *this.
    // For each limb of *this, `other` is multiplied by that limb, moved up by
    // the limb's position and added to a running total. The result is flagged
    // negative exactly when the two operands' sign flags differ.
    BigInteger& BigInteger::operator*=(const BigInteger& other) {
        BigInteger product;
        const int count = static_cast<int>(digits.size());
        for (int pos = 0; pos < count; ++pos) {
            BigInteger partial;
            multByDigitWithoutSign(partial, other, digits[pos]);

            BigInteger aligned;
            shiftLeftByBase(aligned, partial, pos);

            BigInteger total;
            addWithoutSign(total, product, aligned);
            product = total;
        }
        negative = (negative != other.negative);
        digits.swap(product.digits);
        return *this;
    }
    
    // In-place division: replaces *this with the quotient produced by
    // divideWithoutSign and flags it negative exactly when the operands' sign
    // flags differ. The remainder is discarded. Returns *this.
    BigInteger& BigInteger::operator/=(const BigInteger& other) {
        BigInteger quotient;
        BigInteger remainder;
        divideWithoutSign(quotient, remainder, *this, other);
        negative = (negative != other.negative);
        digits.swap(quotient.digits);
        return *this;
    }
    
    // In-place remainder: replaces *this (digits and sign flag) with the
    // remainder produced by divideWithoutSign. The quotient is discarded.
    // Returns *this.
    BigInteger& BigInteger::operator%=(const BigInteger& other) {
        BigInteger quotient;
        BigInteger remainder;
        divideWithoutSign(quotient, remainder, *this, other);
        return *this = remainder;
    }
    
    // Signed "greater than".
    //  - Different sign flags: the operand not flagged negative is the greater.
    //  - Both negative: the one with the smaller magnitude is the greater.
    //  - Both non-negative: the one with the larger magnitude is the greater.
    bool BigInteger::operator>(const BigInteger& other) const {
        if (negative != other.negative) {
            // We are greater exactly when the other side is the negative one.
            return other.negative;
        }
        const int order = cmpWithoutSign(*this, other);
        return negative ? (order < 0) : (order > 0);
    }
    
    // Replaces *this with (*this ^ exponent) mod modulus using binary
    // (square-and-multiply) exponentiation, and returns *this.
    //
    // Only the magnitude of `exponent` is used (its bits are consumed with
    // cmpWithoutSign / shiftRight, which ignore the sign flag).
    // Throws std::domain_error when `modulus` is zero.
    //
    // Fixes over the previous version:
    //  - the initial result 1 is reduced too, so x^0 mod 1 yields 0, not 1;
    //  - the base is reduced once up front, keeping intermediate products small;
    //  - the squaring after the last exponent bit, whose result was never
    //    used, is skipped;
    //  - the modulus is copied, so passing *this as the modulus is safe.
    BigInteger& BigInteger::powAssignUnderMod(const BigInteger& exponent, const BigInteger& modulus) {
        const BigInteger zero("0");
        const BigInteger mod = modulus;  // private copy: *this is overwritten below
        if (cmpWithoutSign(mod, zero) == 0) {
            throw std::domain_error("BigInteger: modulus must not be zero");
        }

        BigInteger e = exponent;
        BigInteger base = *this;
        base %= mod;

        *this = BigInteger("1");
        *this %= mod;  // yields 0 when the modulus is 1

        while (cmpWithoutSign(e, zero) != 0) {
            if (e.isOdd()) {
                *this *= base;
                *this %= mod;
            }
            shiftRight(e);
            // Square the base only if another exponent bit remains to be used.
            if (cmpWithoutSign(e, zero) != 0) {
                base *= BigInteger(base);
                base %= mod;
            }
        }
        return *this;
    }
    
    std::string BigInteger::toString() const {
        std::ostringstream os;
        if (negative)
            os << "-";
        BigInteger tmp = *this;
        BigInteger zero("0");
        BigInteger ten("10");
        tmp.negative = false;
        std::stack s;
        while (cmpWithoutSign(tmp, zero) != 0) {
            BigInteger tmp2, tmp3;
            divideWithoutSign(tmp2, tmp3, tmp, ten);
            s.push((char)(tmp3.digits[0] + '0'));
            tmp = tmp2;
        }
        while (!s.empty()) {
            os << s.top();
            s.pop();
        }
        /*
        for (int i = digits.size()-1; i >= 0; --i) {
            os << digits[i];
            if (i != 0) {
                os << ",";
            }
        }
        */
        return os.str();
    

    And an example usage.

    BigInteger a("87682374682734687"), b("435983748957348957349857345"), c("2348927349872344")
    
    // Will Calculate pow(87682374682734687, 435983748957348957349857345) % 2348927349872344
    a.powAssignUnderMod(b, c);
    

    It's fast too, and supports an unlimited number of digits.

提交回复
热议问题