Karatsuba Multiplication
Table of Contents
Imports and Setup
Imports
# from python
from __future__ import annotations
from functools import partial
from collections import namedtuple
import math
import random
import sys
# from pypi
from joblib import Parallel, delayed
from expects import (
be_a,
contain_exactly,
equal,
expect,
raise_error
)
import altair
import pandas
from graeae import Timer
from graeae.visualization.altair_helpers import output_path, save_chart
Set Up
# Shared timer; the cells further down wrap their work in ``with TIMER:``.
TIMER = Timer()
# Identifier for this post, handed to ``output_path`` to get the output location.
SLUG = "karatsuba-multiplication"
OUTPUT_PATH = output_path(SLUG)
# ``save_chart`` with the output path already filled in, so the plotting
# cells only pass the chart and a name.
save_it = partial(save_chart, output_path=OUTPUT_PATH)
# What ``karatsuba`` returns: the product and the number of base-case
# (single-digit) multiplications it took to get there.
MultiplicationOutput = namedtuple("MultiplicationOutput", ["product", "count"])

# What ``karatsuba_multiplication`` returns: the product and count plus the
# digit count and the two factors, which the plotting code reads.
# The typename matches the variable name so that the repr is accurate and
# pickle (which looks the class up by name in its module) can find it.
PlotMultiplicationOutput = namedtuple(
    "PlotMultiplicationOutput",
    ["product", "count", "digits", "factor_1", "factor_2"])
The Algorithms
Grade School Multiplication
Grade-School multiplication is how most of us are taught to multiply numbers with more than one digit each. Each digit in one number is multiplied by each digit in the other to create a partial product. Once we've gone through all the digits in the first number we sum up all the partial products we calculated to get our final answer.
\begin{algorithm} \caption{Grade-School} \begin{algorithmic} \REQUIRE The input arrays are of the same length ($n$) \INPUT Two arrays representing digits in integers ($a, b$) \OUTPUT The product of the inputs \PROCEDURE{GradeSchool}{$a, b$} \FOR {$j \in \{0 \ldots n - 1\}$} \STATE $carry \gets 0$ \FOR {$i \in \{0 \ldots n - 1\}$} \STATE $product \gets a[i] \times b[j] + carry$ \STATE $partial[j][i + j] \gets product \bmod 10$ \STATE $carry \gets \left\lfloor product/10 \right\rfloor$ \ENDFOR \STATE $partial[j][n + j] \gets carry$ \ENDFOR \STATE $carry \gets 0$ \FOR {$i \in \{0 \ldots 2n - 1\}$} \STATE $sum \gets carry$ \FOR {$j \in \{0 \ldots n - 1\}$} \STATE $sum \gets sum + partial[j][i]$ \ENDFOR \STATE $result[i] \gets sum \bmod 10$ \STATE $carry \gets \left\lfloor sum/10 \right\rfloor$ \ENDFOR \STATE $result[2n] \gets carry$ \RETURN result \ENDPROCEDURE \end{algorithmic} \end{algorithm}
Karatsuba Multiplication
Karatsuba Multiplication improves on the Grade-School algorithm using a trick attributed to Carl Friedrich Gauss, the details of which are a little too much of a diversion here. With all these improvements in multiplication it seems like you need a background in number theory that not all Computer Scientists have (or at least I don't). Maybe take it on faith for now.
\begin{algorithm} \caption{Karatsuba} \begin{algorithmic} \REQUIRE The input arrays are of the same length, which is a power of two \INPUT Two arrays representing digits in integers \OUTPUT The product of the inputs \PROCEDURE{Karatsuba}{$number_1, number_2$} \STATE $\textit{digits} \gets $ \textsc{Length}($number_1$) \IF {$\textit{digits} = 1$} \RETURN $number_1 \times number_2$ \ENDIF \STATE $middle \gets \left\lfloor \frac{\textit{digits}}{2} \right\rfloor$ \STATE \\ \STATE $MostSignificant_1, LeastSignificant_1 \gets $ \textsc{Split}($number_1, middle$) \STATE $MostSignificant_2, LeastSignificant_2 \gets $ \textsc{Split}($number_2, middle$) \STATE \\ \STATE $MostPlusLeast_1 \gets MostSignificant_1 + LeastSignificant_1$ \STATE $MostPlusLeast_2 \gets MostSignificant_2 + LeastSignificant_2$ \STATE \\ \STATE \textit{left} $\gets $ \textsc{Karatsuba}($MostSignificant_1, MostSignificant_2$) \STATE \textit{summed} $\gets $ \textsc{Karatsuba}($MostPlusLeast_1, MostPlusLeast_2$) \STATE \textit{right} $\gets $ \textsc{Karatsuba}($LeastSignificant_1, LeastSignificant_2$) \STATE \\ \STATE \textit{center} $\gets$ (\textit{summed} - \textit{left} - \textit{right}) \STATE \\ \RETURN \textit{left} $\times 10^{\textit{digits}} + \textit{center} \times 10^{\textit{middle}} + \textit{right}$ \ENDPROCEDURE \end{algorithmic} \end{algorithm}
This is a mashup of the Wikipedia version and the Algorithms Illuminated version. It's a little tricky in that we're dealing with integers, in theory, but we have to know the number of digits and how to split it up so in code we're going to have to work with a collection instead, but hopefully this will be clearer in code.
An IntList
class IntegerDigits:
    """A hybrid of an integer and the list of its decimal digits

    The digit list is built the first time it's asked for and is padded on
    the left with zeros so that its length is a power of two. Arithmetic is
    done on the integer and always produces a new ``IntegerDigits`` (so any
    extra padding on the operands is not carried over to the result).

    Args:
        integer: the number to store
    """
    def __init__(self, integer: int) -> None:
        self.integer = integer
        self._digits = None
        return

    @staticmethod
    def _as_integer(value) -> int:
        """Unwrap the right-hand operand of an arithmetic operation

        Args:
            value: an integer or an IntegerDigits

        Returns:
            the plain integer
        """
        return value if isinstance(value, int) else value.integer

    @property
    def digits(self) -> list:
        """The digits for the given integer

        Raises:
            ValueError if the given integer isn't really an integer
            (this includes negative integers, because of the sign)

        Returns:
            zero-padded list of digits
        """
        if self._digits is None:
            digits = [int(digit) for digit in str(self.integer)]
            length = len(digits)
            # smallest power of two that is >= length, using integer
            # arithmetic only (length is always at least one here)
            power_of_two = 1 << (length - 1).bit_length()
            self._digits = [0] * (power_of_two - length) + digits
        return self._digits

    def add_padding(self, padding: int) -> None:
        """Add more zeros to the left of the digits

        Args:
            padding: number of zeros to add to the left of the digits
        """
        self._digits = [0] * padding + self.digits
        return

    def set_length(self, target: int) -> None:
        """Set the total length of the digit list

        Args:
            target: total number of digits to have

        Raises:
            RuntimeError: target is less than the current number of digits
        """
        current = len(self)
        if target < current:
            raise RuntimeError(f"target {target} is less than current {len(self.digits)} digits")
        self.add_padding(target - current)
        return

    def set_equal_length(self, other: IntegerDigits) -> None:
        """Set both self and other to have the same number of digits

        The shorter one is padded up to the length of the longer one.

        Args:
            other: the IntegerDigits to match lengths with
        """
        target = max(len(self), len(other))
        self.set_length(target)
        other.set_length(target)
        return

    def reset(self) -> None:
        """Clean out any generated attributes"""
        self._digits = None
        return

    # collection methods
    def __len__(self) -> int:
        """The number of digits (including padding)"""
        return len(self.digits)

    def __getitem__(self, key) -> IntegerDigits:
        """Slice the digits

        Args:
            key: an index or a slice into the digits

        Returns:
            IntegerDigits built from the selected digits, padding included
        """
        sliced = self.digits[key]
        if isinstance(sliced, int):
            sliced = [sliced]

        # Horner's rule: turn the digits back into an integer without
        # computing a power of ten for every digit
        integer = 0
        for digit in sliced:
            integer = integer * 10 + digit

        gotten = IntegerDigits(integer)
        # preserve any padding
        gotten._digits = sliced
        return gotten

    # integer operations
    def __add__(self, value) -> IntegerDigits:
        """Add an integer or IntegerDigits to this integer"""
        return IntegerDigits(self.integer + self._as_integer(value))

    def __sub__(self, value) -> IntegerDigits:
        """Subtract an integer or IntegerDigits from this integer"""
        return IntegerDigits(self.integer - self._as_integer(value))

    def __mul__(self, value) -> IntegerDigits:
        """multiply integer by integer or IntegerDigits"""
        return IntegerDigits(self.integer * self._as_integer(value))

    # comparisons
    # These compare the stored integer to ``other`` and rely on python's
    # reflected comparisons when ``other`` is another IntegerDigits.
    def __eq__(self, other) -> bool:
        """Compare to integer or IntegerDigits"""
        return other == self.integer

    def __lt__(self, other) -> bool:
        return self.integer < other

    def __le__(self, other) -> bool:
        return self.integer <= other

    def __gt__(self, other) -> bool:
        return self.integer > other

    def __ge__(self, other) -> bool:
        return self.integer >= other

    def __repr__(self) -> str:
        return f"<IntegerDigits: {self.integer}>"
Test it
# Checks for IntegerDigits using a three-digit number, so that one leading
# zero has to be added to reach four digits.
test = IntegerDigits(567)
# build the digits padded to power of 2
expect(len(test.digits)).to(equal(4))
# implement the length dunder method
expect(len(test)).to(equal(4))
# add slicing
# index 0 is the padding digit, index -1 is the ones digit
expect(test[0]).to(equal(0))
expect(test[-1]).to(equal(7))
expect(test[:2].digits).to(contain_exactly(0, 5))
# multiplication
# by a plain integer
product = test * 2
expect(product.integer).to(equal(567 * 2))
# by another IntegerDigits (a single digit needs no padding)
test_2 = IntegerDigits(2)
expect(len(test_2)).to(equal(1))
product = test * test_2
expect(product.integer).to(equal(2 * 567))
# addition
sum_ = test + 10
expect(sum_.integer).to(equal(577))
sum_ = test + test_2
expect(sum_.integer).to(equal(569))
# subtraction
difference = test - 20
expect(difference.integer).to(equal(547))
# the result of a subtraction is allowed to be negative
difference = test_2 - test
expect(difference.integer).to(equal(-565))
An Implementation
Karatsuba Multiplication
def karatsuba(integer_1: IntegerDigits,
              integer_2: IntegerDigits) -> MultiplicationOutput:
    """Multiply integer_1 and integer_2

    Each factor is split into a most-significant and a least-significant
    part and the product is rebuilt from three recursive multiplications
    (instead of the four a naive split would need).

    Args:
        integer_1, integer_2: arrays with equal number of digits

    Raises:
        ValueError: the inputs have different numbers of digits
        RuntimeError: the product came out negative

    Returns:
        product of the integers, count of single-digit multiplications
    """
    digits = len(integer_1)
    if digits == 1:
        return MultiplicationOutput(integer_1 * integer_2, 1)

    # both factors get split at the same index, which only lines up if they
    # have the same number of digits
    if len(integer_2) != digits:
        raise ValueError(
            f"integer_1 has {digits} digits but integer_2 has {len(integer_2)}")

    middle = digits // 2
    # number of digits in the least-significant part, which is what the
    # most-significant part has to be shifted by. It equals ``middle`` when
    # ``digits`` is even and is one more than ``middle`` when it is odd.
    low_digits = digits - middle

    most_significant_1, least_significant_1 = integer_1[:middle], integer_1[middle:]
    most_significant_2, least_significant_2 = integer_2[:middle], integer_2[middle:]

    most_plus_least_1 = most_significant_1 + least_significant_1
    most_plus_least_2 = most_significant_2 + least_significant_2

    # a hack to keep them the same number of digits after the addition
    most_plus_least_1.set_equal_length(most_plus_least_2)

    left, count_left = karatsuba(most_significant_1, most_significant_2)
    summed, count_summed = karatsuba(most_plus_least_1, most_plus_least_2)
    right, count_right = karatsuba(least_significant_1, least_significant_2)

    # (most_1 + least_1)(most_2 + least_2) - most_1 most_2 - least_1 least_2
    #   = most_1 least_2 + least_1 most_2
    center = summed - left - right

    output = left * 10**(2 * low_digits) + center * 10**low_digits + right

    if output < 0:
        raise RuntimeError(f"left: {left} center: {center} right: {right}")
    return MultiplicationOutput(output, count_left + count_summed + count_right)
def karatsuba_multiplication(integer_1: int,
                             integer_2: int,
                             count_padding: bool=True) -> PlotMultiplicationOutput:
    """Sets up and runs the Karatsuba Multiplication

    Args:
        integer_1, integer_2: the two (non-negative) values to multiply
        count_padding: whether the digit count should include the padding

    Returns:
        product, count, digits, factor_1, factor_2
    """
    assert integer_1 >= 0
    assert integer_2 >= 0

    factor_1 = IntegerDigits(integer_1)
    factor_2 = IntegerDigits(integer_2)

    # make them have the same number of digits
    factor_1.set_equal_length(factor_2)

    if count_padding:
        original_digits = len(factor_1)
    else:
        # The number of digits needed to write the larger factor. Going
        # through the string (rather than searching the digit list for the
        # first non-zero digit) means a factor of zero counts as one digit
        # instead of leaving the count undefined.
        original_digits = max(len(str(integer_1)), len(str(integer_2)))

    output = karatsuba(factor_1, factor_2)
    return PlotMultiplicationOutput(product=output.product,
                                    count=output.count,
                                    digits=original_digits,
                                    factor_1=factor_1.integer,
                                    factor_2=factor_2.integer)
Test
# single-digit factors: one digit each, so no padding is added
a, b = 2, 3
output = karatsuba_multiplication(a, b)
expect(output.product).to(equal(a * b))
expect(output.digits).to(equal(1))
# a three-digit factor: with count_padding=True the reported digit count is
# the padded length (four), not three
a = 222
output = karatsuba_multiplication(a, b, True)
expect(output.product).to(equal(666))
expect(output.digits).to(equal(4))
Test
def test_karatsuba(first: int, second: int):
    """Check the Karatsuba product against python's own multiplication

    Args:
        first, second: the factors to multiply

    Raises:
        AssertionError: the two products don't match
    """
    actual = karatsuba_multiplication(first, second).product
    expect(actual).to(equal(first * second))
# Try factors near the square root of sys.maxsize, so that the products land
# near sys.maxsize itself.
limit = int(sys.maxsize**0.5)

for digits in range(limit - 100, limit):
    # two random factors within 1,000 of the current value
    a, b = (random.randrange(digits - 1000, digits + 1000) for _ in range(2))
    try:
        test_karatsuba(a, b)
    except AssertionError as error:
        # show what failed before passing the error along
        print("\n".join((
            f"maxsize: {sys.maxsize}",
            f"a: {a}",
            f"b: {b}",
            f"a x b: {a * b}",
            f"maxsize - a * b: {sys.maxsize - a * b}",
        )))
        raise
Example values from the Algorithms Illuminated website.
# Two 64-digit factors, far bigger than sys.maxsize.
a = 3141592653589793238462643383279502884197169399375105820974944592
b = 2718281828459045235360287471352662497757247093699959574966967627
test_karatsuba(a, b)
Run Time
Using the Master Method
Let's use the Master Method to find the theoretical upper bound for Karatsuba Multiplication.
The basic form of the Master Method is this:
\[ T(n) = a T(\frac{n}{b}) + O(n^d) \]
Variable | Description | Value |
---|---|---|
\(a\) | Recursive calls within the function | 3 |
\(b\) | Amount the input is split up | 2 |
\(d\) | Exponent for the work done outside of the recursion | 1 |
We make three recursive calls within the Karatsuba function and split the data in half within each call. The work done outside the recursion (splitting, adding, subtracting and shifting) is linear in the number of digits, so \(O\left(n^d\right) = O\left(n^1\right)\). \(a > b^d\) (\(3 > 2^1\)) so we have the case where the sub-problems grow faster than the input is reduced, giving us:
\begin{align} T(n) &= O\left(n^{\log_b a}\right) \\ &= O\left(n^{\log_2 3}\right) \end{align}
With Padding
Let's plot the base-case counts alongside the theoretical bounds we found using the Master Method.
First we'll create the numbers to multiply.
# One pair of random factors for every digit count from 1 through 100.
# randrange(10**(digits - 1), 10**digits) only produces numbers with exactly
# ``digits`` digits, so both factors in a pair have the same number of digits.
digit_supply = range(1, 101)
things_to_multiply = [(random.randrange(10**(digits - 1), 10**digits),
random.randrange(10**(digits - 1), 10**digits))
for digits in digit_supply]
Now we'll do the math, running the cases in parallel using Joblib.
# Multiply every pair of factors, one job per pair, and time the whole batch.
with TIMER:
    karatsuba_outputs = Parallel(n_jobs=-1)(
        delayed(karatsuba_multiplication)(factor_1, factor_2)
        for factor_1, factor_2 in things_to_multiply)
Started: 2022-05-13 23:52:06.399789 Ended: 2022-05-13 23:52:09.347825 Elapsed: 0:00:02.948036
Now a little plotting.
# One row per multiplication: the measured base-case count next to the
# theoretical curve and six times the theoretical curve, all against the
# (padded) digit count.
frame = pandas.DataFrame.from_dict(
{"Karatsuba Count": [output.count for output in karatsuba_outputs],
"Digits": [output.digits for output in karatsuba_outputs],
"digits^log2(3)": [output.digits**(math.log2(3)) for output in karatsuba_outputs],
"6 x digits^log2(3)": [6 * output.digits**(math.log2(3)) for output in karatsuba_outputs]
})
# Go from one column per curve to one row per (Digits, Source) pair so the
# curve name can be used for the line color.
melted = frame.melt(id_vars=["Digits"], value_vars=["Karatsuba Count",
"digits^log2(3)",
"6 x digits^log2(3)"],
var_name="Source", value_name="Multiplications")
# Line chart with point markers, one line per Source.
chart = altair.Chart(melted).mark_line(point=altair.OverlayMarkDef()).encode(
x="Digits", y="Multiplications",
color="Source",
tooltip=["Digits",
altair.Tooltip("Multiplications", format=",")]).properties(
title="Basic Multiplications vs Digits (with Padding)",
width=800,
height=525)
save_it(chart, "karatsuba-multiplications")
Since when I added the padding I made sure that the number of digits was a power of two, the numbers are bunched up around those powers of two (so there's a lot of wasted computation, maybe) but the multiplication counts still fall within a constant multiple of our theoretical runtime.
Without Padding
Since I didn't make the karatsuba work without padding this will just show the points spaced out, but the counts will still be based on there being padding.
# The same multiplication, but reporting the digit count without the padding.
# ``partial`` is used instead of a lambda bound to a name: it says what is
# being fixed (``count_padding``) and, unlike a lambda, it can be pickled by
# the standard library when the calls are shipped to worker processes.
unpadded = partial(karatsuba_multiplication, count_padding=False)

with TIMER:
    unpadded_outputs = Parallel(n_jobs=-1)(
        delayed(unpadded)(*thing_to_multiply)
        for thing_to_multiply in things_to_multiply)
Started: 2022-05-13 23:52:20.020179 Ended: 2022-05-13 23:52:22.052011 Elapsed: 0:00:02.031832
# Same data as the padded plot but laid out against the un-padded digit count.
# The two plain "digits^log2(3)" columns are still computed from the padded
# digit counts (karatsuba_outputs); the "(no padding)" columns use the
# un-padded counts (unpadded_outputs).
frame = pandas.DataFrame.from_dict(
{"Karatsuba Count": [output.count for output in unpadded_outputs],
"Digits (pre-padding)": [output.digits for output in unpadded_outputs],
"digits^log2(3)": [output.digits**(math.log2(3)) for output in karatsuba_outputs],
"6 x digits^log2(3)": [6 * output.digits**(math.log2(3)) for output in karatsuba_outputs],
"6 x digits^log2(3) (no padding)": [6 * output.digits**(math.log2(3))
for output in unpadded_outputs],
"n^2 (no padding)": [output.digits**2
for output in unpadded_outputs],
})
# One row per (digit count, Source) pair so the curve name can be used for
# the line color.
melted = frame.melt(id_vars=["Digits (pre-padding)"], value_vars=["Karatsuba Count",
"digits^log2(3)",
"6 x digits^log2(3)",
"6 x digits^log2(3) (no padding)",
"n^2 (no padding)"],
var_name="Source", value_name="Multiplications")
# Plain line chart (no point markers this time), one line per Source.
chart = altair.Chart(melted).mark_line().encode(
x="Digits (pre-padding)", y="Multiplications",
color="Source",
tooltip=[altair.Tooltip("Digits (pre-padding)", type="quantitative"),
altair.Tooltip("Multiplications", format=",")]).properties(
title="Basic Multiplications vs Digits (without Padding)",
width=800,
height=525)
save_it(chart, "karatsuba-multiplications-unpadded")
Since I don't have an easy way to turn off using padding the Multiplication counts are still based on using padding, but this view spreads the digit-counts out so it's a little easier to see. The Multiplication counts are broken up into bands because the padding is based on keeping the number of digits a power of two.
Just for reference, here's the last product we multiplied.
# The last pair multiplied is the one made for the largest digit count.
output = karatsuba_outputs[-1]
# ``product`` is an IntegerDigits, so the plain integer is pulled out to get
# the thousands separators.
print(f"{output.product.integer:,}")
# double-check against python's own multiplication
expect(output.product).to(equal(output.factor_1 * output.factor_2))
56,913,917,723,202,495,576,238,408,244,650,506,926,406,731,625,206,370,840,517,493,281,396,538,892,710,818,017,869,257,379,987,881,688,195,601,612,438,838,803,669,047,089,313,679,236,814,971,999,554,405,895,121,583,263,228,500,933,878,783,310,375,258,385,063,631,332
Sources
Karatsuba Multiplication
Grade-School Multiplication
I took the grade-school algorithm from the Lecture 2 Notes on this course-site.
- McGill University School of Computer Science: COMP 250 (sec 1) [Internet]. [cited 2022 May 14]. Available from: http://crypto.cs.mcgill.ca/~crepeau/COMP250/