Karatsuba Multiplication

Imports and Setup

Imports

# from python
from __future__ import annotations
from functools import partial
from collections import namedtuple

import math
import random
import sys


# from pypi
from joblib import Parallel, delayed
from expects import (
    be_a,
    contain_exactly,
    equal,
    expect,
    raise_error
)

import altair
import pandas

from graeae import Timer
from graeae.visualization.altair_helpers import output_path, save_chart

Set Up

# used below as a context manager (``with TIMER:``) around the parallel runs
TIMER = Timer()
# identifier for this post, handed to output_path
SLUG = "karatsuba-multiplication"
# presumably the folder the charts get written to -- confirm against graeae
OUTPUT_PATH = output_path(SLUG)
# save_chart with the output path already filled in
save_it = partial(save_chart, output_path=OUTPUT_PATH)

# what the recursive function returns: the product and the count of
# base-case multiplications
MultiplicationOutput = namedtuple("MultiplicationOutput", ["product", "count"])

# what the wrapper returns for plotting: the product and count plus the number
# of digits and the two factors that were multiplied
# the typename matches the variable name so the repr names the right class
# and instances can be pickled by reference
PlotMultiplicationOutput = namedtuple(
    "PlotMultiplicationOutput",
    ["product", "count", "digits", "factor_1", "factor_2"])

The Algorithms

Grade School Multiplication

Grade-School multiplication is how most of us are taught to multiply numbers with more than one digit each. Each digit in one number is multiplied by each digit in the other to create a partial product. Once we've gone through all the digits in the first number we sum up all the partial products we calculated to get our final answer.

\begin{algorithm}
\caption{Grade-School}
\begin{algorithmic}
\REQUIRE The input arrays are of the same length ($n$)
\INPUT Two arrays representing digits in integers ($a, b$)
\OUTPUT The product of the inputs

\PROCEDURE{GradeSchool}{$a, b$}

\FOR {$j \in \{0 \ldots n - 1\}$}
  \STATE $carry \gets 0$
  \FOR {$i \in \{0 \ldots n - 1\}$}
    \STATE $product \gets a[i] \times b[j] + carry$
    \STATE $partial[j][i + j] \gets product \bmod 10$
    \STATE $carry \gets \left\lfloor product/10 \right\rfloor$
  \ENDFOR
  \STATE $partial[j][n + j] \gets carry$
\ENDFOR

\STATE $carry \gets 0$

\FOR {$i \in \{0 \ldots 2n - 1\}$}
  \STATE $sum \gets carry$
  \FOR {$j \in \{0 \ldots n - 1\}$}
    \STATE $sum \gets sum + partial[j][i]$
  \ENDFOR
  \STATE $result[i] \gets sum \bmod 10$
  \STATE $carry \gets \left\lfloor sum/10 \right\rfloor$
\ENDFOR

\STATE $result[2n] \gets carry$
\RETURN result
\ENDPROCEDURE
\end{algorithmic}
\end{algorithm}

Karatsuba Multiplication

Karatsuba Multiplication improves on the Grade-School algorithm using a trick from Carl Friedrich Gauss, which is a little too much of a diversion to go into here. With all these improvements in multiplication it seems like you need a background in number theory that not all computer scientists have (or at least I don't). Maybe take it on faith for now.

\begin{algorithm}
\caption{Karatsuba}
\begin{algorithmic}
\REQUIRE The input arrays are of the same length
\INPUT Two arrays representing digits in integers
\OUTPUT The product of the inputs

\PROCEDURE{Karatsuba}{$number_1, number_2$}

\STATE $\textit{digits} \gets $ \textsc{Length}($number_1$)

\IF {$\textit{digits} = 1$}
  \RETURN $number_1 \times number_2$
\ENDIF

\STATE $middle \gets \left\lfloor \frac{\textit{digits}}{2} \right\rfloor$

\STATE \\
\STATE $MostSignificant_1, LeastSignificant_1 \gets $ \textsc{Split}($number_1, middle$)
\STATE $MostSignificant_2, LeastSignificant_2 \gets $ \textsc{Split}($number_2, middle$)
\STATE \\
\STATE $MostPlusLeast_1 \gets MostSignificant_1 + LeastSignificant_1$
\STATE $MostPlusLeast_2 \gets MostSignificant_2 + LeastSignificant_2$
\STATE \\

\STATE \textit{left} $\gets $ \textsc{Karatsuba}($MostSignificant_1, MostSignificant_2$)
\STATE \textit{summed} $\gets $ \textsc{Karatsuba}($MostPlusLeast_1, MostPlusLeast_2$)
\STATE \textit{right} $\gets $ \textsc{Karatsuba}($LeastSignificant_1, LeastSignificant_2$)
\STATE \\
\STATE \textit{center} $\gets$ (\textit{summed} - \textit{left} - \textit{right})
\STATE \\
\RETURN \textit{left} $\times 10^{\textit{digits}} + \textit{center} \times 10^{\textit{middle}} + \textit{right}$
\ENDPROCEDURE
\end{algorithmic}
\end{algorithm}

This is a mashup of the Wikipedia version and the Algorithms Illuminated version. It's a little tricky: in theory we're dealing with integers, but we also have to know the number of digits and how to split them up, so in code we're going to have to work with a collection of digits instead. Hopefully this will be clearer in code.

An IntegerDigits Class

class IntegerDigits:
    """A hybrid integer and digits list

    The value is kept as a plain integer in ``integer``. Its decimal digits
    are built the first time they are asked for and are zero-padded on the
    left so that the number of digits is a power of two.

    Args:
     integer: the number to store
    """
    # defining __eq__ already makes the class unhashable, this just says so
    # out loud (the digit list can change so hashing wouldn't be safe)
    __hash__ = None

    def __init__(self, integer: int) -> None:
        self.integer = integer
        self._digits = None
        return

    @staticmethod
    def _unwrap(value):
        """Get the plain value out of an IntegerDigits

        Args:
         value: an IntegerDigits or anything else

        Returns:
         value.integer for an IntegerDigits, otherwise the value unchanged
        """
        return value.integer if isinstance(value, IntegerDigits) else value

    @property
    def digits(self) -> list:
        """The digits for the given integer

        Raises:
         ValueError if the string form of the integer has a character that
         isn't a digit (e.g. a minus sign or a decimal point)

        Returns:
         zero-padded list of digits
        """
        if self._digits is None:
            digits = [int(digit) for digit in str(self.integer)]
            length = len(digits)
            # smallest power of two >= length, using exact integer arithmetic
            power_of_two = 1 << (length - 1).bit_length()
            self._digits = [0] * (power_of_two - length) + digits
        return self._digits

    def add_padding(self, padding: int) -> None:
        """Add more zeros to the left of the digits

        Args:
         padding: number of zeros to add to the left of the digits
        """
        self._digits = [0] * padding + self.digits
        return

    def set_length(self, target: int) -> None:
        """Set the total length of the digit list

        Args:
         target: total number of digits to have

        Raises:
         RuntimeError: target is less than the current number of digits
        """
        current = len(self)
        if target < current:
            raise RuntimeError(f"target {target} is less than current {current} digits")

        self.add_padding(target - current)
        return

    def set_equal_length(self, other: IntegerDigits) -> None:
        """Set both self and other to have the same number of digits

        Args:
         other: the IntegerDigits to match (it gets padded too if it's shorter)
        """
        target = max(len(self), len(other))
        self.set_length(target)
        other.set_length(target)
        return

    def reset(self) -> None:
        """Clean out any generated attributes"""
        self._digits = None
        return

    # collection methods

    def __len__(self) -> int:
        """The number of digits (including the padding)"""
        return len(self.digits)

    def __getitem__(self, key) -> IntegerDigits:
        """Slice the digits

        Args:
         key: index or slice to apply to the digit list

        Returns:
         IntegerDigits whose digits are exactly the ones selected
        """
        sliced = self.digits[key]
        if isinstance(sliced, int):
            # a single index gives back a bare digit
            sliced = [sliced]
        value = 0
        for digit in sliced:
            value = value * 10 + digit
        gotten = IntegerDigits(value)
        # preserve any padding
        gotten._digits = sliced
        return gotten

    # integer operations

    def __add__(self, value) -> IntegerDigits:
        """Add an integer or IntegerDigits to this integer"""
        if not isinstance(value, (int, IntegerDigits)):
            return NotImplemented
        return IntegerDigits(self.integer + self._unwrap(value))

    def __sub__(self, value) -> IntegerDigits:
        """Subtract an integer or IntegerDigits from this integer"""
        if not isinstance(value, (int, IntegerDigits)):
            return NotImplemented
        return IntegerDigits(self.integer - self._unwrap(value))

    def __mul__(self, value) -> IntegerDigits:
        """multiply integer by integer or IntegerDigits"""
        if not isinstance(value, (int, IntegerDigits)):
            return NotImplemented
        return IntegerDigits(self.integer * self._unwrap(value))

    # comparisons
    # the other side is unwrapped first so that comparing two IntegerDigits
    # doesn't depend on python finding a reflected method

    def __eq__(self, other) -> bool:
        """Compare to integer or IntegerDigits"""
        return self.integer == self._unwrap(other)

    def __lt__(self, other) -> bool:
        return self.integer < self._unwrap(other)

    def __le__(self, other) -> bool:
        return self.integer <= self._unwrap(other)

    def __gt__(self, other) -> bool:
        return self.integer > self._unwrap(other)

    def __ge__(self, other) -> bool:
        return self.integer >= self._unwrap(other)

    def __repr__(self) -> str:
        return f"<IntegerDigits: {self.integer}>"

Test it

# 567 has three digits so one zero gets added on the left
test = IntegerDigits(567)
# build the digits padded to power of 2
expect(len(test.digits)).to(equal(4))

# implement the length dunder method
expect(len(test)).to(equal(4))

# add slicing
# the first digit is the padding zero and the last is the ones digit
expect(test[0]).to(equal(0))
expect(test[-1]).to(equal(7))
# a slice keeps the padding zero it covers
expect(test[:2].digits).to(contain_exactly(0, 5))

# multiplication
# by a plain integer
product = test * 2
expect(product.integer).to(equal(567 * 2))
# a single digit doesn't get any padding
test_2 = IntegerDigits(2)
expect(len(test_2)).to(equal(1))
# by another IntegerDigits
product = test * test_2
expect(product.integer).to(equal(2 * 567))

# addition
# with a plain integer
sum_ = test + 10
expect(sum_.integer).to(equal(577))

# with another IntegerDigits
sum_ = test + test_2
expect(sum_.integer).to(equal(569))

# subtraction
# of a plain integer
difference = test - 20
expect(difference.integer).to(equal(547))

# of another IntegerDigits (the result can be negative)
difference = test_2 - test
expect(difference.integer).to(equal(-565))

An Implementation

Karatsuba Multiplication

def karatsuba(integer_1: IntegerDigits,
              integer_2: IntegerDigits) -> MultiplicationOutput:
    """Multiply integer_1 and integer_2

    Each factor is split into a most-significant and a least-significant
    part and three recursive multiplications are combined to build the
    product.

    Args:
     integer_1, integer_2: arrays with equal number of digits

    Raises:
     RuntimeError: the combined product came out negative

    Returns:
     product of the integers, count of base-case multiplications
    """
    digits = len(integer_1)
    if digits == 1:
        return MultiplicationOutput(integer_1 * integer_2, 1)
    middle = digits//2
    # the least-significant part holds whatever is left after the split, so
    # it's what sets the powers of ten when recombining (for an even number
    # of digits this is the same as ``middle``)
    low_digits = digits - middle

    most_significant_1, least_significant_1 = integer_1[:middle], integer_1[middle:]
    most_significant_2, least_significant_2 = integer_2[:middle], integer_2[middle:]

    most_plus_least_1 = most_significant_1 + least_significant_1
    most_plus_least_2 = most_significant_2 + least_significant_2

    # a hack to keep them the same number of digits after the addition
    most_plus_least_1.set_equal_length(most_plus_least_2)

    left, count_left = karatsuba(most_significant_1, most_significant_2)
    summed, count_summed = karatsuba(most_plus_least_1, most_plus_least_2)
    right, count_right = karatsuba(least_significant_1, least_significant_2)

    # the cross terms without needing a fourth multiplication
    center = summed - left - right

    output = left * 10**(2 * low_digits) + center * 10**low_digits + right

    if output < 0:
        raise RuntimeError(f"left: {left} center: {center} right: {right}")

    return MultiplicationOutput(output, count_left + count_summed + count_right)
def karatsuba_multiplication(integer_1: int,
                             integer_2: int,
                             count_padding: bool=True) -> PlotMultiplicationOutput:
    """Sets up and runs the Karatsuba Multiplication

    Args:
     integer_1, integer_2: the two values to multiply (neither can be negative)
     count_padding: whether the digit count should include the padding

    Raises:
     AssertionError: one of the integers is negative

    Returns:
     product, count, digits, factor_1, factor_2
    """
    assert integer_1 >= 0
    assert integer_2 >= 0

    factor_1 = IntegerDigits(integer_1)
    factor_2 = IntegerDigits(integer_2)

    if not count_padding:
        # digits as written, before any zeros get added to the left
        # (zero counts as one digit instead of leaving the count undefined)
        original_digits = max(len(str(integer_1)), len(str(integer_2)))

    # make them have the same number of digits
    factor_1.set_equal_length(factor_2)

    if count_padding:
        original_digits = len(factor_1)

    output = karatsuba(factor_1, factor_2)
    return PlotMultiplicationOutput(product=output.product,
                                    count=output.count,
                                    digits=original_digits,
                                    factor_1=factor_1.integer,
                                    factor_2=factor_2.integer)

Test

# single digits: no padding so the digit count stays at one
a, b = 2, 3
output = karatsuba_multiplication(a, b)
expect(output.product).to(equal(a * b))
expect(output.digits).to(equal(1))

# three digits get padded to four and the padding is included in the count
a = 222
output = karatsuba_multiplication(a, b, True)
expect(output.product).to(equal(666))
expect(output.digits).to(equal(4))

Test

def test_karatsuba(first: int, second: int):
    """Check the Karatsuba product against python's multiplication

    Args:
     first, second: the integers to multiply
    """
    actual = karatsuba_multiplication(first, second).product
    expect(actual).to(equal(first * second))
    return
# try random factors around the square root of sys.maxsize
limit = int(sys.maxsize**0.5)
for center in range(limit - 100, limit):
    lowest, highest = center - 1000, center + 1000
    a = random.randrange(lowest, highest)
    b = random.randrange(lowest, highest)
    try:
        test_karatsuba(a, b)
    except AssertionError:
        # show what failed before passing the error on
        report = (f"maxsize: {sys.maxsize}",
                  f"a: {a}",
                  f"b: {b}",
                  f"a x b: {a * b}",
                  f"maxsize - a * b: {sys.maxsize - a * b}")
        for line in report:
            print(line)
        raise

Example values from the Algorithms Illuminated website.

# two 64-digit factors, checked against python's own product
a = 3141592653589793238462643383279502884197169399375105820974944592
b = 2718281828459045235360287471352662497757247093699959574966967627
test_karatsuba(a, b)

Run Time

Using the Master Method

Let's use the Master Method to find the theoretical upper bound for Karatsuba Multiplication.

The basic form of the Master Method is this:

\[ T(n) = a T(\frac{n}{b}) + O(n^d) \]

Variable Description Value
\(a\) Recursive calls within the function 3
\(b\) Amount the input is split up 2
\(d\) Exponent for the work done outside of the recursion 1

We make three recursive calls within the Karatsuba function and split the data in half within each call. The work done outside the recursion (splitting, adding, and subtracting the digits) is linear in the number of digits, so \(O\left(n^d\right) = O\left(n^1\right)\). \(a > b^d\) (since \(3 > 2^1\)) so we have the case where the sub-problems grow faster than the input is reduced, giving us:

\begin{align} T(n) &= O\left(n^{\log_b a}\right) \\ &= O\left(n^{\log_2 3}\right) \end{align}

With Padding

Let's plot the base-case counts alongside the theoretical bounds we found using the Master Method.

First we'll create the numbers to multiply.

# one pair of random factors for each digit-count from 1 through 100
digit_supply = range(1, 101)
things_to_multiply = []
for digit_count in digit_supply:
    # smallest value with this many digits, and the first with one more
    smallest, too_big = 10**(digit_count - 1), 10**digit_count
    things_to_multiply.append((random.randrange(smallest, too_big),
                               random.randrange(smallest, too_big)))

Now we'll do the math, running the cases in parallel using Joblib.

# wrap the function once and hand joblib one call for each pair of factors
run_karatsuba = delayed(karatsuba_multiplication)

with TIMER:
    karatsuba_outputs = Parallel(n_jobs=-1)(
        run_karatsuba(factor_1, factor_2)
        for factor_1, factor_2 in things_to_multiply)
Started: 2022-05-13 23:52:06.399789
Ended: 2022-05-13 23:52:09.347825
Elapsed: 0:00:02.948036

Now a little plotting.

# the exponent from the Master Method
log_3 = math.log2(3)

padded_digits = [output.digits for output in karatsuba_outputs]
bounds = [digits**log_3 for digits in padded_digits]

frame = pandas.DataFrame.from_dict({
    "Karatsuba Count": [output.count for output in karatsuba_outputs],
    "Digits": padded_digits,
    "digits^log2(3)": bounds,
    "6 x digits^log2(3)": [6 * bound for bound in bounds],
})

# one row for each (digits, line) pair so the lines can be colored by source
lines = ["Karatsuba Count", "digits^log2(3)", "6 x digits^log2(3)"]
melted = frame.melt(id_vars=["Digits"],
                    value_vars=lines,
                    var_name="Source",
                    value_name="Multiplications")

tooltip = ["Digits", altair.Tooltip("Multiplications", format=",")]
chart = altair.Chart(melted).mark_line(point=altair.OverlayMarkDef()).encode(
    x="Digits",
    y="Multiplications",
    color="Source",
    tooltip=tooltip,
).properties(
    title="Basic Multiplications vs Digits (with Padding)",
    width=800,
    height=525,
)

save_it(chart, "karatsuba-multiplications")

Figure Missing

Since when I added the padding I made sure that the number of digits was a power of two, the numbers are bunched up around those powers of two (so there's a lot of wasted computation, maybe) but the multiplication counts still fall within a constant multiple of our theoretical runtime.

Without Padding

Since I didn't make the karatsuba work without padding this will just show the points spaced out, but the counts will still be based on there being padding.

# the same multiplication but reporting the digits without the padding
# (a partial instead of a lambda assigned to a name)
unpadded = partial(karatsuba_multiplication, count_padding=False)

with TIMER:
    unpadded_outputs = Parallel(n_jobs=-1)(
        delayed(unpadded)(*thing_to_multiply)
        for thing_to_multiply in things_to_multiply)
Started: 2022-05-13 23:52:20.020179
Ended: 2022-05-13 23:52:22.052011
Elapsed: 0:00:02.031832
# the exponent from the Master Method
log_3 = math.log2(3)

# these bounds use the padded digit counts from the first run
padded_bounds = [output.digits**log_3 for output in karatsuba_outputs]
# while the x-axis and the remaining lines use the counts without padding
unpadded_digits = [output.digits for output in unpadded_outputs]

frame = pandas.DataFrame.from_dict({
    "Karatsuba Count": [output.count for output in unpadded_outputs],
    "Digits (pre-padding)": unpadded_digits,
    "digits^log2(3)": padded_bounds,
    "6 x digits^log2(3)": [6 * bound for bound in padded_bounds],
    "6 x digits^log2(3) (no padding)": [6 * digits**log_3
                                        for digits in unpadded_digits],
    "n^2 (no padding)": [digits**2 for digits in unpadded_digits],
})

lines = ["Karatsuba Count",
         "digits^log2(3)",
         "6 x digits^log2(3)",
         "6 x digits^log2(3) (no padding)",
         "n^2 (no padding)"]
melted = frame.melt(id_vars=["Digits (pre-padding)"],
                    value_vars=lines,
                    var_name="Source",
                    value_name="Multiplications")

tooltip = [altair.Tooltip("Digits (pre-padding)", type="quantitative"),
           altair.Tooltip("Multiplications", format=",")]
chart = altair.Chart(melted).mark_line().encode(
    x="Digits (pre-padding)",
    y="Multiplications",
    color="Source",
    tooltip=tooltip,
).properties(
    title="Basic Multiplications vs Digits (without Padding)",
    width=800,
    height=525,
)

save_it(chart, "karatsuba-multiplications-unpadded")

Figure Missing

Since I don't have an easy way to turn off using padding the Multiplication counts are still based on using padding, but this view spreads the digit-counts out so it's a little easier to see. The Multiplication counts are broken up into bands because the padding is based on keeping the number of digits a power of two.

Just for reference, here's the last product we multiplied.

# the last pair multiplied was the one with the most digits
output = karatsuba_outputs[-1]
# the product is an IntegerDigits so the plain integer is what gets formatted
print(f"{output.product.integer:,}")

# double-check it against python's multiplication of the same factors
expect(output.product).to(equal(output.factor_1 * output.factor_2))
56,913,917,723,202,495,576,238,408,244,650,506,926,406,731,625,206,370,840,517,493,281,396,538,892,710,818,017,869,257,379,987,881,688,195,601,612,438,838,803,669,047,089,313,679,236,814,971,999,554,405,895,121,583,263,228,500,933,878,783,310,375,258,385,063,631,332

Sources

Grade-School Multiplication

I took the grade-school algorithm from the Lecture 2 Notes on this course-site.