You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1399 lines
43 KiB

/*==LICENSE==*
CyanWorlds.com Engine - MMOG client, server and tools
Copyright (C) 2011 Cyan Worlds, Inc.
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
Additional permissions under GNU GPL version 3 section 7
If you modify this Program, or any covered work, by linking or
combining it with any of RAD Game Tools Bink SDK, Autodesk 3ds Max SDK,
NVIDIA PhysX SDK, Microsoft DirectX SDK, OpenSSL library, Independent
JPEG Group JPEG library, Microsoft Windows Media SDK, or Apple QuickTime SDK
(or a modified version of those libraries),
containing parts covered by the terms of the Bink SDK EULA, 3ds Max EULA,
PhysX SDK EULA, DirectX SDK EULA, OpenSSL and SSLeay licenses, IJG
JPEG Library README, Windows Media SDK EULA, or QuickTime SDK EULA, the
licensors of this Program grant you additional
permission to convey the resulting work. Corresponding Source for a
non-source form of such a combination shall include the source code for
the parts of OpenSSL and IJG JPEG Library used as well as that of the covered
work.
You can contact Cyan Worlds, Inc. by email legal@cyan.com
or by snail mail at:
Cyan Worlds, Inc.
14617 N Newport Hwy
Mead, WA 99021
*==LICENSE==*/
/*****************************************************************************
*
* $/Plasma20/Sources/Plasma/NucleusLib/pnUtils/Private/pnUtBigNum.cpp
*
***/
#include "../Pch.h"
#pragma hdrstop
/****************************************************************************
*
* Constants and macros
*
***/
// Number of bits in one BigNum word (Val)
const unsigned VAL_BITS = 8 * sizeof(BigNum::Val);
// One past the largest value a single word can hold (1 << VAL_BITS),
// expressed in the double-width type
const BigNum::DVal VAL_RANGE = ((BigNum::DVal)1) << VAL_BITS;
// LOW: converts a double-width value to a single word
#define LOW(dval) ((Val)(dval))
// HIGH: the double-width value divided by VAL_RANGE, converted to a word
#define HIGH(dval) ((Val)((dval) / VAL_RANGE))
// PACK: builds a double-width value equal to high * VAL_RANGE + low
#define PACK(low, high) ((DVal)((high) * VAL_RANGE + (low)))
// ALLOC_TEMP: obtains room for 'count' words from _alloca and hands it to
// (struct).UseTempAlloc. Because _alloca is used, the storage belongs to the
// calling function's stack frame.
#define ALLOC_TEMP(struct, count) \
(struct).UseTempAlloc((Val *)_alloca((count) * sizeof(Val)), count)
/****************************************************************************
*
* BigNum private methods
*
***/
//===========================================================================
void BigNum::SetVal (unsigned index, Val value) {
ARRAY(Val)::operator[](index) = value;
}
//===========================================================================
void BigNum::SetVal (unsigned index, DVal value, Val * carry) {
ARRAY(Val)::operator[](index) = LOW(value);
*carry = HIGH(value);
}
//===========================================================================
void BigNum::Trim (unsigned count) {
    // Reduce the word count to at most 'count', additionally dropping any
    // zero words found at the top of that range
    ASSERT(count <= Count());
    unsigned used = count;
    for (; used; --used) {
        if (ARRAY(Val)::operator[](used - 1))
            break;
    }
    SetCountFewer(used);
}
//===========================================================================
BigNum * BigNum::UseTempAlloc (Val * ptr, unsigned count) {
    // Attach caller-provided storage of 'count' words to this number and
    // mark it as temporary (the destructor detaches temporary storage).
    // Returns this so the call can be used inside an expression.
    this->m_isTemp = true;
    this->AttachTemp(ptr, count);
    return this;
}
/****************************************************************************
*
* BigNum public methods
*
***/
//===========================================================================
BigNum::BigNum ()
:   m_isTemp(false)
{
    // Default construction: no temporary storage attached, no value set here
}
//===========================================================================
BigNum::BigNum (const BigNum & a)
:   m_isTemp(false)
{
    // Copy construction: take a copy of a's words
    this->Set(a);
}
//===========================================================================
BigNum::BigNum (unsigned a)
:   m_isTemp(false)
{
    // Construct from a machine integer
    this->Set(a);
}
//===========================================================================
BigNum::BigNum (unsigned bytes, const void * data)
:   m_isTemp(false)
{
    // Construct from a raw buffer of 'bytes' bytes (see FromData)
    this->FromData(bytes, data);
}
//===========================================================================
BigNum::BigNum (const wchar str[], Val radix)
:   m_isTemp(false)
{
    // Construct by parsing a string in the given radix (see FromStr)
    this->FromStr(str, radix);
}
//===========================================================================
BigNum::~BigNum () {
    // Only storage attached through UseTempAlloc is detached here
    if (!m_isTemp)
        return;
    Detach();
}
//===========================================================================
void BigNum::Add (const BigNum & a, Val b) {
// this = a + b
const unsigned count = a.Count();
GrowToCount(count + 1, true);
unsigned index = 0;
Val carry = b;
for (; index < count; ++index)
SetVal(index, (DVal)((DVal)a[index] + (DVal)carry), &carry);
if (carry)
SetVal(index++, carry);
Trim(index);
}
//===========================================================================
void BigNum::Add (const BigNum & a, const BigNum & b) {
// this = a + b
const unsigned aCount = a.Count();
const unsigned bCount = b.Count();
const unsigned count = aCount + bCount;
GrowToCount(count + 1, true);
unsigned index = 0;
Val carry = 0;
for (; index < count; ++index) {
Val aVal = (index < aCount) ? a[index] : (Val)0;
Val bVal = (index < bCount) ? b[index] : (Val)0;
SetVal(index, (DVal)((DVal)aVal + (DVal)bVal + (DVal)carry), &carry);
}
if (carry)
SetVal(index++, carry);
Trim(index);
}
//===========================================================================
int BigNum::Compare (Val a) const {
    // -1 if (this < a)
    //  0 if (this == a)
    //  1 if (this > a)
    const unsigned count = Count();

    // The number must be trimmed (no zero top word)
    ASSERT(!count || (*this)[count - 1]);

    // No words at all: this is zero
    if (!count)
        return a ? -1 : 0;

    // More than one word: this cannot fit in a single word
    if (count > 1)
        return 1;

    // Exactly one word: compare it directly
    const Val thisVal = (*this)[0];
    if (thisVal > a)
        return 1;
    if (thisVal < a)
        return -1;
    return 0;
}
//===========================================================================
int BigNum::Compare (const BigNum & a) const {
// -1 if (this < a)
// 0 if (this == a)
// 1 if (this > a)
// Handle the case where this number has more digits than the comparand
const unsigned thisCount = Count();
const unsigned compCount = a.Count();
ASSERT(!thisCount || (*this)[thisCount - 1]);
ASSERT(!compCount || a[compCount - 1]);
if (thisCount > compCount)
return 1;
// Handle the case where this number has fewer digits than the comparand
if (thisCount < compCount)
return -1;
// Handle the case where both numbers share the same number of digits
for (unsigned index = thisCount; index--; ) {
Val thisVal = (*this)[index];
Val compVal = a[index];
if (thisVal == compVal)
continue;
return (thisVal > compVal) ? 1 : -1;
}
return 0;
}
//===========================================================================
void BigNum::Div (const BigNum & a, Val b, Val * remainder) {
// this = a / b, remainder = a % b
const unsigned count = a.Count();
SetCount(count);
*remainder = 0;
for (unsigned index = count; index--; ) {
DVal value = PACK(a[index], *remainder);
SetVal(index, (Val)(value / b));
*remainder = (Val)(value % b);
}
Trim(count);
}
//===========================================================================
void BigNum::Div (const BigNum & a, const BigNum & b, BigNum * remainder) {
    // this = a / b, remainder = a % b
    // either this or remainder may be nil
    ASSERT(this != remainder);

    // Check for division by zero
    ASSERT(b.Count() && b[b.Count() - 1]);

    // DivNormalized needs the highest bit of the denominator's top word to
    // be set; work out how far both operands must be shifted left for that
    const unsigned topBit = MathHighBitPos(b[b.Count() - 1]);
    const unsigned shift = 8 * sizeof(Val) - topBit - 1;

    // Without a shift the operands are used directly; otherwise shifted
    // copies are built in stack storage
    BigNum aaBuffer;
    BigNum bbBuffer;
    BigNum * aa = (BigNum *)&a;
    BigNum * bb = (BigNum *)&b;
    if (shift) {
        aa = ALLOC_TEMP(aaBuffer, a.Count() + 1);
        bb = ALLOC_TEMP(bbBuffer, b.Count() + 1);
        aa->Shl(a, shift);
        bb->Shl(b, shift);
    }

    // Perform the division
    DivNormalized(*aa, *bb, remainder);

    // The remainder of the shifted operands is itself shifted; undo that
    if (remainder)
        remainder->Shr(*remainder, shift);
}
//===========================================================================
void BigNum::DivNormalized (const BigNum & a, const BigNum & b, BigNum * remainder) {
    // this = a / b, remainder = a % b
    // either this or remainder may be nil
    // high bit of b must be set
    //
    // NOTE(review): the "this may be nil" convention relies on calling a
    // member function through a null pointer, which standard C++ does not
    // define; the 'if (this)' tests below are kept for existing callers.
    ASSERT(this != remainder);

    // Check for division by zero
    ASSERT(b.Count() && b[b.Count() - 1]);

    // Verify that the operands are normalized
    ASSERT(MathHighBitPos(b[b.Count() - 1]) == 8 * sizeof(Val) - 1);

    // Handle the case where the denominator is greater than the numerator:
    // the quotient is zero and the remainder is the numerator itself
    if ((b.Count() > a.Count()) || (b.Compare(a) > 0)) {
        if (remainder && (remainder != &a))
            remainder->Set(a);
        if (this)
            ZeroCount();
        return;
    }

    // Create a distinct buffer for the denominator if necessary (it must
    // not be overwritten while the quotient or remainder is being built)
    BigNum denomTemp;
    BigNum * denom = ((&b != this) && (&b != remainder)) ? (BigNum *)&b : ALLOC_TEMP(denomTemp, b.Count());
    denom->Set(b);

    // Store the numerator into the remainder buffer; it is reduced in place
    BigNum numerTemp;
    BigNum * numer = remainder ? remainder : ALLOC_TEMP(numerTemp, a.Count());
    numer->Set(a);

    // Prepare the destination buffer
    const unsigned numerCount = numer->Count();
    const unsigned denomCount = denom->Count();
    if (this)
        this->SetCount(numerCount + 1 - denomCount);

    // Calculate the quotient one word at a time. The trial divisor t is
    // the top denominator word plus one, so each estimate is never too
    // large.
    DVal t = (DVal)((DVal)(*denom)[denomCount - 1] + (DVal)1);
    for (unsigned quotientIndex = numerCount + 1 - denomCount; quotientIndex--; ) {

        // Calculate the approximate value of the next quotient word,
        // erring on the side of underestimation
        Val low = (*numer)[quotientIndex + denomCount - 1];
        Val high = (quotientIndex + denomCount < numerCount) ? (*numer)[quotientIndex + denomCount] : (Val)0;
        ASSERT(high < t);
        Val quotient = (Val)(PACK(low, high) / t);

        // Calculate the product of the denominator and this quotient word
        // (using zero for all lower quotient words) and subtract the product
        // from the current numerator
        if (quotient) {
            Val borrow = 0;
            Val carry = 0;
            for (unsigned denomIndex = 0; denomIndex != denomCount; ++denomIndex) {
                DVal product = (DVal)(Mul((*denom)[denomIndex], quotient) + carry);
                carry = HIGH(product);
                numer->SetVal(quotientIndex + denomIndex, (DVal)((DVal)(*numer)[quotientIndex + denomIndex] - (DVal)LOW(product) - (DVal)borrow), &borrow);
                // SetVal reports an all-ones high word on underflow;
                // negating it turns that into a borrow of one
                borrow = (Val)((Val)0 - (Val)borrow);
            }
            if (quotientIndex + denomCount != numerCount) {
                // Propagate the final carry and borrow into the numerator
                // word just above the denominator's span. This previously
                // indexed with 'denomIndex' after its loop had ended, which
                // only compiles with the legacy MSVC for-scope extension;
                // its value at that point was denomCount.
                numer->SetVal(quotientIndex + denomCount, (DVal)((DVal)(*numer)[quotientIndex + denomCount] - (DVal)carry - (DVal)borrow), &borrow);
                carry = 0;
            }
            ASSERT(!carry);
            ASSERT(!borrow);
        }

        // Check whether we underestimated the quotient word, and adjust
        // it if necessary
        for (;;) {

            // Test whether the current numerator is still greater than or
            // equal to the denominator
            if ((quotientIndex + denomCount == numerCount) || !(*numer)[quotientIndex + denomCount]) {
                bool numerLess = false;
                for (unsigned denomIndex = denomCount; !numerLess && denomIndex--; ) {
                    Val numerVal = (*numer)[quotientIndex + denomIndex];
                    Val denomVal = (*denom)[denomIndex];
                    numerLess = (numerVal < denomVal);
                    if (numerVal != denomVal)
                        break;
                }
                if (numerLess)
                    break;
            }

            // Increment the quotient by one, and correct the current
            // numerator for this adjustment by subtracting the denominator
            ++quotient;
            Val borrow = 0;
            for (unsigned denomIndex = 0; denomIndex != denomCount; ++denomIndex) {
                numer->SetVal(quotientIndex + denomIndex, (DVal)((DVal)(*numer)[quotientIndex + denomIndex] - (DVal)(*denom)[denomIndex] - (DVal)borrow), &borrow);
                borrow = (Val)((Val)0 - (Val)borrow);
            }
            if (borrow)
                numer->SetVal(quotientIndex + denomCount, (DVal)((DVal)(*numer)[quotientIndex + denomCount] - (DVal)borrow), &borrow);
            ASSERT(!borrow);

        }
        ASSERT((quotientIndex + denomCount == numerCount) || !(*numer)[quotientIndex + denomCount]);

        // Store the final quotient word
        if (this)
            this->SetVal(quotientIndex, quotient);

    }

    // The final remainder is the remaining portion of the numerator
    if (remainder) {
        ASSERT(remainder == numer);
        remainder->Trim(denomCount);
    }

    // Trim the result
    if (this)
        this->Trim(numerCount + 1 - denomCount);
}
//===========================================================================
void BigNum::FromData (unsigned bytes, const void * data) {
ASSERT(data || !bytes);
// Calculate the number of words required to hold the data
unsigned count = (bytes + sizeof(Val) - 1) / sizeof(Val);
SetCount(count);
// Fill in whole words
unsigned index = 0;
unsigned offset = 0;
for (; offset + sizeof(Val) <= bytes; ++index, offset += sizeof(Val))
SetVal(index, *(const Val *)((const byte *)data + offset));
// Fill in the final partial word
if (offset < bytes) {
Val value = 0;
MemCopy(&value, (const byte *)data + offset, bytes - offset);
SetVal(index, value);
}
}
//===========================================================================
void BigNum::FromStr (const wchar str[], Val radix) {
ASSERT(str);
// Decode the prefix
if (str[0] == L'0') {
if ((str[1] == L'x') || (str[1] == L'X')) {
str += 2;
if (!radix)
radix = 16;
}
else if ((str[1] >= L'0') && (str[1] <= L'9')) {
str += 1;
if (!radix)
radix = 8;
}
else if (!radix) {
ZeroCount();
return;
}
}
else if (!radix)
radix = 10;
// Decode the number
ZeroCount();
for (; *str; ++str) {
// Decode the next character
Val value;
if ((*str >= L'0') && (*str <= '9'))
value = (Val)(*str - L'0');
else if ((*str >= L'a') && (*str <= L'z'))
value = (Val)(*str + 10 - L'a');
else if ((*str >= L'A') && (*str <= L'Z'))
value = (Val)(*str + 10 - L'A');
else
break;
if (value >= radix)
break;
// Apply it to the result
Mul(*this, radix);
Add(*this, value);
}
}
//===========================================================================
void BigNum::Gcd (const BigNum & a, const BigNum & b) {
    // this = greatest common divisor of a and b

    // Work on stack copies so neither operand is modified
    const unsigned maxCount = max(a.Count(), b.Count());
    BigNum aa;
    BigNum bb;
    ALLOC_TEMP(aa, maxCount + 1);
    ALLOC_TEMP(bb, maxCount + 1);
    aa.Set(a);
    bb.Set(b);

    // Euclid's algorithm: replace (aa, bb) by (bb % aa, aa) until aa is
    // zero; this always holds the most recent nonzero aa (or b if a is zero)
    Set(bb);
    for (; aa.Count(); bb.Set(*this)) {
        Set(aa);
        aa.Mod(bb, aa);
    }
}
//===========================================================================
const void * BigNum::GetData (unsigned * bytes) const {
    // Returns Ptr(); when 'bytes' is non-nil it also receives Bytes()
    if (bytes) {
        const unsigned size = Bytes();
        *bytes = size;
    }
    return Ptr();
}
//===========================================================================
unsigned BigNum::HighBitPos () const {
// returns the position of the highest set bit, or -1 if no bits are set
for (unsigned index = Count(); index--; ) {
Val val = (*this)[index];
if (!val)
continue;
return index * 8 * sizeof(Val) + MathHighBitPos(val);
}
return (unsigned)-1;
}
//===========================================================================
bool BigNum::InverseMod (const BigNum & a, const BigNum & b) {
// finds value for this such that (a ^ -1) == (this mod b)
// returns false if a has no inverse modulo b
// Preconditions (asserted): a and b are nonzero and a < b.
// On a false return this number is left unmodified by this function.
// Verify that a and b are nonzero
ASSERT(a.Count());
ASSERT(b.Count());
// Verify that a is less than b
ASSERT(a.Compare(b) < 0);
// Verify that either a or b is odd. If both are even then they cannot
// possibly be relatively prime, so there cannot be a solution.
if (!(a.IsOdd() || b.IsOdd()))
return false;
// Allocate buffers for intermediate values
// Two triples of working values; u and t are pointers so the triples can
// be exchanged by swapping pointers instead of copying numbers
BigNum uArray[3];
BigNum tArray[3];
BigNum * u = uArray;
BigNum * t = tArray;
// Find the inverse using the extended Euclidean algorithm
// Initial state: u = (1, 0, b), t = (a, b - 1, a)
u[0].SetOne();
u[1].SetZero();
u[2].Set(b);
t[0].Set(a);
t[1].Sub(b, 1);
t[2].Set(a);
do {
do {
// While u[2] is even, halve the whole triple; a and b are first added
// to u[0] and u[1] when either is odd so that both halve exactly
if (!u[2].IsOdd()) {
if (u[0].IsOdd() || u[1].IsOdd()) {
u[0].Add(u[0], a);
u[1].Add(u[1], b);
}
u[0].Shr(u[0], 1);
u[1].Shr(u[1], 1);
u[2].Shr(u[2], 1);
}
// Exchange the triples when t[2] is even or u[2] is the smaller value
if (!t[2].IsOdd() || (u[2].Compare(t[2]) < 0))
SWAP(u, t);
} while (!u[2].IsOdd());
// The numbers are unsigned, so raise u[0] and u[1] by a and b until the
// subtractions below cannot underflow
while ((u[0].Compare(t[0]) < 0) || (u[1].Compare(t[1]) < 0)) {
u[0].Add(u[0], a);
u[1].Add(u[1], b);
}
u[0].Sub(u[0], t[0]);
u[1].Sub(u[1], t[1]);
u[2].Sub(u[2], t[2]);
} while (t[2].Count());
// Reduce u[0] and u[1] together while both are at least a and b
while ((u[0].Compare(a) >= 0) && (u[1].Compare(b) >= 0)) {
u[0].Sub(u[0], a);
u[1].Sub(u[1], b);
}
// If the greatest common denominator is not one then there is no
// solution
if (u[2].Compare(1))
return false;
// Return the solution
// this = b - u[1]
Sub(b, u[1]);
return true;
}
//===========================================================================
bool BigNum::IsMultiple (Val a) const {
    // returns true if (this % a) == 0
    // Long division from the top word down, keeping only the remainder
    DVal remainder = 0;
    unsigned index = Count();
    while (index) {
        --index;
        const DVal packed = PACK((*this)[index], remainder);
        remainder = (DVal)(packed % a);
    }
    return remainder == 0;
}
//===========================================================================
bool BigNum::IsOdd () const {
    // returns true if this is an odd number
    // A number with no words is zero, which is even
    if (!Count())
        return false;
    return ((*this)[0] & 1) != 0;
}
//===========================================================================
bool BigNum::IsPrime () const {
    // returns true if there is a strong likelihood that this is prime

    // Verify that the number is odd, or is exactly equal to two
    if (!Count() || (!((*this)[0] & 1) && ((Count() > 1) || ((*this)[0] > 2))))
        return false;

    static const Val smallPrimes[] = {3, 5, 7, 11};
    unsigned loop;

    // Numbers no larger than the biggest small prime are decided exactly.
    // This has to come before the divisibility test below: every small
    // prime is a multiple of itself, so testing divisibility first reported
    // 3, 5, 7 and 11 as composite (and let 1 through as prime).
    if (Compare(smallPrimes[arrsize(smallPrimes)-1]) <= 0) {
        if (!Compare(2))
            return true;
        for (loop = 0; loop != arrsize(smallPrimes); ++loop)
            if (!Compare(smallPrimes[loop]))
                return true;
        return false;
    }

    // Verify that the number is not evenly divisible by a small prime
    for (loop = 0; loop != arrsize(smallPrimes); ++loop)
        if (IsMultiple(smallPrimes[loop]))
            return false;

    // Rabin-Miller Test

    // Calculate b, where b is the number of times 2 divides (this - 1)
    BigNum this_1;
    ALLOC_TEMP(this_1, Count());
    this_1.Sub(*this, 1);
    const unsigned b = this_1.LowBitPos();

    // Calculate m, such that this == 1 + 2 ^ b * m
    BigNum m;
    ALLOC_TEMP(m, Count());
    m.Shr(this_1, b);

    // For a number of witnesses, test whether the witness demonstrates this
    // number to be composite via Fermat's Little Theorem, or has a
    // nontrivial square root mod n
    static const Val witnesses[] = {3, 5, 7};
    BigNum z;
    ALLOC_TEMP(z, 2 * (Count() + 1));
    for (loop = 0; loop != arrsize(witnesses); ++loop) {

        // Initialize z to (witness ^ m % this)
        z.PowMod(witnesses[loop], m, *this);

        // This passes the test if (z == 1)
        if (!z.Compare(1))
            continue;

        // Otherwise keep squaring until z reaches (this - 1)
        for (unsigned j = 0; z.Compare(this_1); ) {

            // This fails the test if we reach b iterations.
            ++j;
            if (j == b)
                return false;

            // Square z. This fails the test if z mod this equals 1.
            z.MulMod(z, z, *this);
            if (!z.Compare(1))
                return false;

        }

    }
    return true;
}
//===========================================================================
unsigned BigNum::LowBitPos () const {
// returns the position of the lowest set bit, or -1 if no bits are set
for (unsigned index = 0, count = Count(); index < count; ++index) {
Val val = (*this)[index];
if (!val)
continue;
for (unsigned bit = 0; ; ++bit)
if (val & (1 << bit))
return index * 8 * sizeof(Val) + bit;
}
return (unsigned)-1;
}
//===========================================================================
void BigNum::Mod (const BigNum & a, const BigNum & b) {
// this = a % b
((BigNum *)nil)->Div(a, b, this);
}
//===========================================================================
void BigNum::ModNormalized (const BigNum & a, const BigNum & b) {
// this = a % b
// high bit of b must be set
((BigNum *)nil)->DivNormalized(a, b, this);
}
//===========================================================================
BigNum::DVal BigNum::Mul (BigNum::Val a, BigNum::Val b) {
    // returns a * b, computed in the double-width type
    DVal product = (DVal)a;
    product *= (DVal)b;
    return product;
}
//===========================================================================
void BigNum::Mul (const BigNum & a, Val b) {
// this = a * b
const unsigned count = a.Count();
GrowToCount(count + 1, true);
unsigned index = 0;
Val carry = 0;
for (; index < count; ++index)
SetVal(index, (DVal)(Mul(a[index], b) + carry), &carry);
if (carry)
SetVal(index++, carry);
Trim(index);
}
//===========================================================================
void BigNum::Mul (const BigNum & a, const BigNum & b) {
// this = a * b
// The result is sized to aCount + bCount words and trimmed at the end.
const unsigned aCount = a.Count();
const unsigned bCount = b.Count();
const unsigned count = aCount + bCount;
SetCount(count);
if (!count)
return;
// We perform the multiplication from left to right, so that we don't
// overwrite any operand words before they're used in the case that
// the destination is not distinct from either of the operands
// The top word only ever receives carries, so it starts at zero
SetVal(count - 1, 0);
for (unsigned index = count - 1; index--; ) {
// Iterate every combination of aIndex + bIndex that adds up to
// index, and sum the products of those words
DVal value = 0;
// aStart and aTerm bound aIndex so that both aIndex and
// index - aIndex stay inside their operands
const unsigned aStart = (index < bCount) ? 0 : (index + 1 - bCount);
const unsigned aTerm = min(index + 1, aCount);
for (unsigned aIndex = aStart; aIndex != aTerm; ++aIndex) {
// Accumulate the product of this pair of words
value = (DVal)(Mul(a[aIndex], b[index - aIndex]) + value);
// If the product exceeds the word size, apply carry
// (the carry ripples into the already-written words above index)
Val carry = HIGH(value);
for (unsigned carryIndex = index + 1; carry; ++carryIndex)
SetVal(carryIndex, (DVal)((DVal)(*this)[carryIndex] + (DVal)carry), &carry);
value = LOW(value);
}
// Store the sum of products as the final value for index
SetVal(index, LOW(value));
}
Trim(count);
}
//===========================================================================
void BigNum::MulMod (const BigNum & a, const BigNum & b, const BigNum & c) {
    // this = a * b % c
    if (this == &c) {
        // The destination is also the modulus, so the product has to be
        // built in scratch storage to keep c intact for the reduction
        BigNum temp;
        ALLOC_TEMP(temp, a.Count() + b.Count());
        temp.Mul(a, b);
        Mod(temp, c);
        return;
    }

    // Otherwise multiply into this and reduce in place
    Mul(a, b);
    Mod(*this, c);
}
//===========================================================================
void BigNum::PowMod (Val a, const BigNum & b, const BigNum & c) {
    // this = a ^ b % c
    // c must be nonzero and trimmed (its top word is used for normalizing)

    // Verify that b is distinct from this
    BigNum bbBuffer;
    const BigNum & bb = (&b != this) ? b : bbBuffer;
    if (&bb != &b) {
        ALLOC_TEMP(bbBuffer, b.Count());
        bbBuffer.Set(b);
    }

    // Generate a table which may allow us to process two bits at once.
    // The powers are formed in double width so overflow is detected
    // exactly; comparing the truncated products against 'a' misses cases
    // where the wrapped value happens to be at least a.
    const DVal aSquared = Mul(a, a);
    const DVal aCubed = Mul(LOW(aSquared), a);
    Val aMult[4] = {
        1,
        a,
        LOW(aSquared),
        LOW(aCubed)
    };
    bool overflow = HIGH(aSquared) || HIGH(aCubed) || (c.Compare(aMult[3]) <= 0);

    // Normalize the denominator so that the high bit is set. The result
    // will be built shifted an equivalent amount.
    const unsigned shift = 8 * sizeof(Val) - MathHighBitPos(c[c.Count() - 1]) - 1;
    BigNum cc;
    ALLOC_TEMP(cc, c.Count() + 1);
    cc.Shl(c, shift);

    // Perform the exponentiation from left to right two bits at a time
    if (!overflow) {
        SetBits(shift, 1);
        bool anySet = false;
        for (unsigned index = bb.Count(); index--; )
            for (unsigned bit = 8 * sizeof(Val); bit; ) {
                bit -= 2;

                // Squaring doubles the built-in shift, so shift back down
                // once before each reduction
                if (anySet) {
                    Square(*this);
                    Shr(*this, shift);
                    ModNormalized(*this, cc);
                    Square(*this);
                    Shr(*this, shift);
                    ModNormalized(*this, cc);
                }

                unsigned entry = (bb[index] >> bit) & 3;
                if (entry) {
                    Mul(*this, aMult[entry]);
                    ModNormalized(*this, cc);
                    anySet = true;
                }

            }
    }

    // Perform the exponentiation from left to right a single bit at a time
    else {
        SetBits(shift, 1);
        bool anySet = false;
        for (unsigned index = bb.Count(); index--; )
            for (unsigned bit = 8 * sizeof(Val); bit--; ) {

                // As in the two-bit path, the squared value carries the
                // shift twice and must be shifted back before reducing;
                // without this every squaring multiplied the result by an
                // extra 2 ^ shift.
                if (anySet) {
                    Square(*this);
                    Shr(*this, shift);
                    ModNormalized(*this, cc);
                }

                // Probe the bit in the word type rather than with a signed
                // int shift
                if (bb[index] & ((Val)1 << bit)) {
                    Mul(*this, a);
                    ModNormalized(*this, cc);
                    anySet = true;
                }

            }
    }

    // Denormalize the result
    Shr(*this, shift);
}
//===========================================================================
void BigNum::PowMod (const BigNum & a, const BigNum & b, const BigNum & c) {
// this = a ^ b % c
// c's top word is read for normalizing, so c must be nonzero
// Verify that a and b are distinct from this
// (whichever operand aliases this is redirected to a stack copy of this)
BigNum distinctTemp;
const BigNum & aa = (&a != this) ? a : distinctTemp;
const BigNum & bb = (&b != this) ? b : distinctTemp;
if ((&aa != &a) || (&bb != &b)) {
ALLOC_TEMP(distinctTemp, Count());
distinctTemp.Set(*this);
}
// Generate a table which will allow us to process two bits at once
// a2 = a^2 % c, a3 = a^3 % c
BigNum a2;
BigNum a3;
ALLOC_TEMP(a2, 2 * aa.Count() + 1);
ALLOC_TEMP(a3, 3 * aa.Count() + 1);
a2.Mul(aa, aa);
a2.Mod(a2, c);
a3.Mul(aa, a2);
a3.Mod(a3, c);
// Indexed by a two-bit group of the exponent; entry 0 is never
// dereferenced because a zero group performs no multiplication
const BigNum * aMult[] = {
nil,
&aa,
&a2,
&a3
};
// Normalize the denominator so that the high bit is set. The result
// will be built shifted an equivalent amount.
const unsigned shift = 8 * sizeof(Val) - MathHighBitPos(c[c.Count() - 1]) - 1;
BigNum cc;
ALLOC_TEMP(cc, c.Count() + 1);
cc.Shl(c, shift);
// Perform the exponentiation from left to right two bits at a time
// Start from 1 << shift, i.e. the value one in shifted form
SetBits(shift, 1);
bool anySet = false;
for (unsigned index = bb.Count(); index--; )
for (unsigned bit = 8 * sizeof(Val); bit; ) {
bit -= 2;
// Square twice for the two exponent bits; each Square is followed by
// a Shr by 'shift' before the reduction
if (anySet) {
Square(*this);
Shr(*this, shift);
ModNormalized(*this, cc);
Square(*this);
Shr(*this, shift);
ModNormalized(*this, cc);
}
// Multiply in the table entry selected by this two-bit group
unsigned entry = (bb[index] >> bit) & 3;
if (entry) {
Mul(*this, *aMult[entry]);
ModNormalized(*this, cc);
anySet = true;
}
}
// Denormalize the result
Shr(*this, shift);
}
//===========================================================================
void BigNum::Rand (const BigNum & a, BigNum * seed) {
    // this = random number less than a
    ASSERT(seed != &a);
    ASSERT(seed != this);

    // If the bound is this object, keep a stack copy of it, since this is
    // overwritten by every candidate
    BigNum distinctTemp;
    const BigNum & aa = (&a != this) ? a : distinctTemp;
    if (&aa != &a) {
        ALLOC_TEMP(distinctTemp, a.Count());
        distinctTemp.Set(a);
    }

    // Count the number of bits in a
    const unsigned bits = aa.HighBitPos() + 1;

    // Draw candidates with that many bits until one falls below the bound
    do {
        Rand(bits, seed);
    } while (Compare(aa) >= 0);
}
//===========================================================================
void BigNum::Rand (unsigned bits, BigNum * seed) {
// this = random number with bits or fewer bits
// The seed is read and updated in place so successive calls continue the
// sequence; it must be a different object from this.
ASSERT(seed != this);
// Prepare the output buffer
const unsigned count = (bits + 8 * sizeof(Val) - 1) / (8 * sizeof(Val));
SetCount(count);
if (!count)
return;
// Prepare the seed
// (an empty seed is extended to one word so there is something to read)
unsigned seedCount = seed->Count();
if (!seedCount)
seed->SetCount(++seedCount);
unsigned seedIndex = 0;
// Produce a random number with the correct number of words
for (unsigned index = 0; index < count; ++index) {
// Read the next word of the seed
// (it is XORed with a constant while the output index still equals
// the seed index, i.e. before the seed index has wrapped around)
dword randValue = (*seed)[seedIndex] ^ ((index == seedIndex) ? 0x075bd924 : 0);
// Produce one word of randomness, 16 bits at a time
Val value = 0;
for (unsigned bit = 0; bit < 8 * sizeof(Val); bit += 16) {
// Note that A * Q + R == 0x7fffffff
const dword A = 0xbc8f;
const dword Q = 0xadc8;
const dword R = 0x0d47;
dword div = randValue / Q;
randValue = A * (randValue - Q * div) - R * div;
randValue &= 0x7fffffff;
value |= (randValue & 0xffff) << bit;
}
// Store the random word
SetVal(index, value);
// Update the seed and move to the seed next word
seed->SetVal(seedIndex, (Val)randValue);
if (++seedIndex >= seedCount)
seedIndex = 0;
}
// Mask the final word to contain the correct number of bits
Val mask = (Val)((Val)-1 >> (count * 8 * sizeof(Val) - bits));
SetVal(count - 1, (Val)((*this)[count - 1] & mask));
// Trim the result
Trim(count);
// Rotate the seed so the next unused seed word will be the first seed
// word used in the next random operation
if (seedIndex) {
BigNum saved;
ALLOC_TEMP(saved, seedCount);
saved.Set(*seed);
for (unsigned index = 0; index < seedCount; ++index)
(*seed)[index] = saved[(index + seedIndex) % seedCount];
}
}
//===========================================================================
void BigNum::RandPrime (unsigned bits, BigNum * seed) {
// this = random number that passes IsPrime, built from the given seed.
// Requests above 128 bits recurse on half the bit count; smaller requests
// search upward from a random starting point.
// Calculate the required number of words to hold the generated number
unsigned count = (bits + 8 * sizeof(Val) - 1) / (8 * sizeof(Val));
// For large bit counts, calculate the prime number as 2 * q * n + 1,
// where q is a random prime with fewer bits, and n is a random number
// chosen as follows:
// n >= (2 ^ (bits - 1) - 1) / (2 * q)
// n < (2 ^ bits - 1) / (2 * q)
if (bits > 128) {
// Choose a prime number q, and multiply it by 2
BigNum q_2;
ALLOC_TEMP(q_2, count / 2 + 2);
q_2.RandPrime(bits / 2, seed);
q_2.Mul(q_2, 2);
// Calculate the lower bound
// (SetBits(0, bits - 1) produces bits - 1 one bits, i.e. 2 ^ (bits - 1) - 1)
BigNum lowerBound;
ALLOC_TEMP(lowerBound, count + 1);
lowerBound.SetBits(0, bits - 1);
lowerBound.Div(lowerBound, q_2, nil);
// Calculate the upper bound
BigNum upperBound;
ALLOC_TEMP(upperBound, count + 1);
upperBound.SetBits(0, bits);
upperBound.Div(upperBound, q_2, nil);
// Calculate the number of bits in the upper bound
unsigned upperBoundBits = upperBound.HighBitPos() + 1;
for (;;) {
// Choose a random number between the lower and upper bounds
// (candidates outside the range are rejected and redrawn)
Rand(upperBoundBits, seed);
if (Compare(upperBound) >= 0)
continue;
if (Compare(lowerBound) < 0)
continue;
// Calculate 2 * q * n + 1
Mul(*this, q_2);
Add(*this, 1);
// Test whether the result is prime
if (IsPrime())
break;
}
}
// For small bit counts, choose a random number with the requested
// number of bits, then keep incrementing it until we find a prime
else {
// Define the upper bound for a number with the requested number
// of bits
// (SetBits(bits, 1) produces the single bit 2 ^ bits)
BigNum bound;
ALLOC_TEMP(bound, count + 1);
bound.SetBits(bits, 1);
for (;;) {
// Choose a random number with (bits - 1) bits
Rand(bits - 1, seed);
// Subtract it from the upper bound to generate a number with
// the high bit set
Sub(bound, *this);
// Keep incrementing the number until we find a prime
// (make it odd first, then step by two)
if (!IsOdd())
Add(*this, 1);
while (!IsPrime())
Add(*this, 2);
// If the number reached or exceeded the upper bound, try again
if (Compare(bound) < 0)
break;
}
}
}
//===========================================================================
void BigNum::Set (const BigNum & a) {
    // this = a
    // Assigning a number to itself is a no-op
    if (this == &a)
        return;

    // Match a's word count, then copy the words from the top down
    const unsigned words = a.Count();
    SetCount(words);
    for (unsigned remaining = words; remaining; --remaining)
        SetVal(remaining - 1, a[remaining - 1]);
}
//===========================================================================
void BigNum::Set (unsigned a) {
    // this = a
    // Zero is represented by an empty word array
    ZeroCount();
    if (!a)
        return;

    // Peel off one word at a time, least significant first, until what
    // remains fits in the word just stored
    unsigned index = 0;
    for (;;) {
        SetCount(index + 1);
        SetVal(index, LOW(a));
        if (a < VAL_RANGE)
            break;
        a = (unsigned)(a / VAL_RANGE);
        ++index;
    }
}
//===========================================================================
void BigNum::SetBits (unsigned setBitsOffset, unsigned setBitsCount) {
// this = binary [1...][0...], where 'setBitsOffset' is the number of
// zero bits and 'setBitsCount' is the number of one bits
if (!setBitsCount) {
ZeroCount();
return;
}
const unsigned setBitsTerm = setBitsOffset + setBitsCount - 1;
const unsigned bitsPerWord = 8 * sizeof(Val);
const unsigned firstSetWord = setBitsOffset / bitsPerWord;
const unsigned lastSetWord = (setBitsOffset + setBitsCount - 1) / bitsPerWord;
Val firstSetMask = (Val)((Val)-1 << (setBitsOffset % bitsPerWord));
Val lastSetMask = (Val)((Val)-1 >> (bitsPerWord - setBitsTerm % bitsPerWord - 1));
if (firstSetWord == lastSetWord)
firstSetMask = lastSetMask = (Val)(firstSetMask & lastSetMask);
SetCount(lastSetWord + 1);
unsigned index = 0;
for (; index < firstSetWord; ++index)
SetVal(index, 0);
SetVal(index++, firstSetMask);
if (firstSetWord == lastSetWord)
return;
for (; index < lastSetWord; ++index)
SetVal(index, (Val)-1);
SetVal(index, lastSetMask);
}
//===========================================================================
void BigNum::SetOne () {
// this = 1
SetCount(1);
SetVal(0, 1);
}
//===========================================================================
void BigNum::SetZero () {
// this = 0
ZeroCount();
}
//===========================================================================
void BigNum::Shl (const BigNum & a, unsigned b) {
// this = a << b
ASSERT(b < 8 * sizeof(Val));
if (!b) {
Set(a);
return;
}
const unsigned bInv = 8 * sizeof(Val) - b;
const unsigned count = a.Count();
SetCount(count + 1);
Val curr = 0;
for (unsigned index = count; index >= 1; --index) {
Val next = a[index - 1];
SetVal(index, (Val)((next >> bInv) | (curr << b)));
curr = next;
}
SetVal(0, (Val)(curr << b));
Trim(count + 1);
}
//===========================================================================
void BigNum::Shr (const BigNum & a, unsigned b) {
// this = a >> b
ASSERT(b < 8 * sizeof(Val));
if (!b) {
Set(a);
return;
}
const unsigned bInv = 8 * sizeof(Val) - b;
const unsigned count = a.Count();
SetCount(count);
if (!count)
return;
Val curr = a[0];
for (unsigned index = 0; index + 1 < count; ++index) {
Val next = a[index + 1];
SetVal(index, (Val)((next << bInv) | (curr >> b)));
curr = next;
}
SetVal(count - 1, (Val)(curr >> b));
Trim(count);
}
//===========================================================================
void BigNum::Square (const BigNum & a) {
// this = a * a
// The result is sized to 2 * aCount words and trimmed at the end.
const unsigned aCount = a.Count();
const unsigned count = 2 * aCount;
SetCount(count);
if (!count)
return;
// We perform the multiplication from left to right, so that we don't
// overwrite any operand words before they're used in the case that
// the destination is not distinct from the operand
// The top word only ever receives carries, so it starts at zero
SetVal(count - 1, 0);
for (unsigned index = count - 1; index--; ) {
// Iterate every combination of source indices that adds up to
// index, and sum the products of those words
DVal value = 0;
unsigned aIndex = (index < aCount) ? 0 : (index + 1 - aCount);
unsigned bIndex;
// Only pairs with aIndex <= bIndex are visited; the loop ends once
// bIndex - aIndex would become negative
for (; (int)((bIndex = index - aIndex) - aIndex) >= 0; ++aIndex) {
// Calculate the product of this pair of words
DVal product = Mul(a[aIndex], a[bIndex]);
// Add the product to the sum. If it exceeds the word size,
// apply carry.
value = (DVal)(value + product);
Val carry = HIGH(value);
unsigned carryIndex;
for (carryIndex = index + 1; carry; ++carryIndex)
SetVal(carryIndex, (DVal)((DVal)(*this)[carryIndex] + (DVal)carry), &carry);
value = LOW(value);
// If this pair of words should be multiplied twice, add the
// product again.
// (a pair of two different indices also stands for its mirror image)
if (aIndex == bIndex)
continue;
value = (DVal)(value + product);
carry = HIGH(value);
for (carryIndex = index + 1; carry; ++carryIndex)
SetVal(carryIndex, (DVal)((DVal)(*this)[carryIndex] + (DVal)carry), &carry);
value = LOW(value);
}
// Store the sum of products as the final value for index
SetVal(index, LOW(value));
}
Trim(count);
}
//===========================================================================
void BigNum::Sub (const BigNum & a, Val b) {
    // this = a - b
    // a must be at least b; a negative result trips the assertion below
    const unsigned count = a.Count();
    SetCount(count);

    // The single-word subtrahend is simply the initial borrow
    Val borrow = b;
    for (unsigned index = 0; index < count; ++index) {
        SetVal(index, (DVal)((DVal)a[index] - (DVal)borrow), &borrow);
        // SetVal reports an all-ones high word on underflow; negating it
        // turns that into a borrow of one
        borrow = (Val)((Val)0 - (Val)borrow);
    }
    ASSERT(!borrow);

    // The loop always ends at 'count'. The loop variable itself is out of
    // scope here under standard for-scoping, so it cannot be passed to Trim.
    Trim(count);
}
//===========================================================================
void BigNum::Sub (const BigNum & a, const BigNum & b) {
    // this = a - b
    // a must be at least b; a negative result trips the assertion below
    const unsigned count = a.Count();
    const unsigned bCount = b.Count();
    GrowToCount(count, true);

    Val borrow = 0;
    for (unsigned index = 0; index < count; ++index) {
        // Words beyond the end of b count as zero
        Val bVal = (index < bCount) ? b[index] : (Val)0;
        SetVal(index, (DVal)((DVal)a[index] - (DVal)bVal - (DVal)borrow), &borrow);
        // SetVal reports an all-ones high word on underflow; negating it
        // turns that into a borrow of one
        borrow = (Val)((Val)0 - (Val)borrow);
    }
    ASSERT(!borrow);

    // The loop always ends at 'count'. The loop variable itself is out of
    // scope here under standard for-scoping, so it cannot be passed to Trim.
    Trim(count);
}
//===========================================================================
void BigNum::ToStr (BigNum * buffer, Val radix) const {
    // Writes this number as a nul-terminated wide string into the word
    // storage of 'buffer'. Radix 16 output starts with "0x" and radix 8
    // output with "0"; digits above 9 are written as lowercase letters.
    ASSERT(this != buffer);

    // Calculate the number of characters in the prefix
    unsigned prefixChars = 0;
    if (radix == 16)
        prefixChars = 2;
    else if (radix == 8)
        prefixChars = 1;

    // Count how many digits the largest possible word needs in this radix
    unsigned charsPerVal = 0;
    Val testVal = (Val)-1;
    while (testVal) {
        ++charsPerVal;
        testVal = (Val)(testVal / radix);
    }

    // Preallocate space for the output string: digits for every word, the
    // prefix, and the terminator, rounded up to whole words
    const unsigned charsTotal = max(1, Count()) * charsPerVal + prefixChars + 1;
    buffer->SetCount((charsTotal * sizeof(wchar) + sizeof(Val) - 1) / sizeof(Val));

    // Build the prefix
    wchar * prefix = (wchar *)buffer->Ptr();
    if (prefixChars) {
        prefix[0] = L'0';
        if (radix == 16)
            prefix[1] = L'x';
        else
            ASSERT(prefixChars == 1);
    }

    // Build the number starting with the least significant digit, by
    // repeatedly dividing a working copy by the radix
    wchar * start = prefix + prefixChars;
    wchar * curr = start;
    BigNum work;
    ALLOC_TEMP(work, Count());
    work.Set(*this);
    do {
        // Extract the next value
        Val remainder;
        work.Div(work, radix, &remainder);

        // Encode it as a character in the output string
        const wchar digit = (remainder >= 10)
            ? (wchar)(L'a' + (unsigned)remainder - 10)
            : (wchar)(L'0' + (unsigned)remainder);
        *curr++ = digit;
    } while (work.Count());
    *curr = 0;

    // The digits were produced in reverse; swap them into reading order
    wchar * left = start;
    wchar * right = curr - 1;
    while (left < right) {
        SWAP(*left, *right);
        ++left;
        --right;
    }
}