Maximum Pairwise Product
Imports
Since the grader only uses the Python standard library, I'll try to stick to that. The stress test isn't part of the graded assignment, though, so there I'll cheat and use numpy to generate the random input.
# python standard library
from datetime import (
    datetime,
    timedelta,
)
import random

# from pypi
from numba import jit
import numpy
Problem Statement
Find the maximum product of two numbers at distinct positions in a sequence of non-negative integers.
- Input: A sequence of non-negative integers.
- Output: The maximum value that can be obtained by multiplying two different elements from the sequence.
Given a sequence of non-negative numbers \(a_1,\ldots,a_n\), compute
\[ \max_{1 \le i \neq j \le n} a_i a_j \]
The indices \(i\) and \(j\) must be different, although the values \(a_i\) and \(a_j\) may be equal.
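As a sanity check on the definition, here is a direct transcription of the formula. It uses `itertools.combinations`, which pairs up positions rather than values, so two equal values at different indices still count. This is only an illustrative sketch: the name `max_pairwise_product_definition` is mine, and it is quadratic, just like the brute-force version further down.

```python
from itertools import combinations


def max_pairwise_product_definition(numbers):
    """Transcribes max over i != j of a_i * a_j.

    combinations() yields each pair of *positions* exactly once, so
    equal values at different indices (like the two 8s) are allowed.
    """
    return max(a * b for a, b in combinations(numbers, 2))


print(max_pairwise_product_definition([7, 5, 14, 2, 8, 8, 10, 1, 2, 3]))  # 140
print(max_pairwise_product_definition([8, 8, 1]))  # 64
```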
Part | Description |
---|---|
First Line | \(n\), the number of input values |
Second Line | \(a_1 \ldots a_n\) - space-separated list of values |
Output | the maximum pairwise product from the input. |
Constraints | \(2 \le n \le 2 \times 10^5; 0 \le a_1,\ldots,a_n\le 2 \times 10^5\) |
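The input arrives in the two-line format above. Here is a minimal parsing sketch, assuming the whole input is available as one string, for example from `sys.stdin.read()` in the grader. The name `parse_input` is hypothetical and not part of the assignment's starter code.

```python
def parse_input(text):
    """Turns the two-line problem input into a list of integers.

    Line one holds n and line two holds the n space-separated values;
    str.split() treats newlines and spaces alike.
    """
    tokens = text.split()
    n = int(tokens[0])
    values = [int(token) for token in tokens[1:n + 1]]
    assert len(values) == n, "expected {} values, got {}".format(n, len(values))
    return values


print(parse_input("3\n1 2 3\n"))  # [1, 2, 3]
```

Reading everything in one call avoids a separate `input()` call per line, and the assertion catches a second line that is shorter than \(n\) claims.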
Example Values
First Input | Second Input | Output |
---|---|---|
5 | 5 6 2 7 4 | 42 |
3 | 1 2 3 | 6 |
10 | 7 5 14 2 8 8 10 1 2 3 | 140 |
Limit | Value |
---|---|
Time | 5 seconds |
Memory | 512 MB |
Some Constants
This translates some of the values from the problem statement into Python so the tests can use them.
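MAX_TIME = timedelta(seconds=5)
MAX_INPUTS = 2 * 10**5
MAX_VALUE = MAX_INPUTS

# 1..MAX_VALUE, so the answer is 200,000 * 199,999
MAX_CASE = dict(output=39999800000,
                inputs=numpy.arange(1, MAX_VALUE + 1, dtype=numpy.int64))
assert len(MAX_CASE["inputs"]) == MAX_INPUTS
assert max(MAX_CASE["inputs"]) == MAX_VALUE, "Actual: {}".format(max(MAX_CASE["inputs"]))

# expected output -> input values
# dictionary keys must be unique, so [1, 3] stands in for [1, 2]
# (a second key of 2 would silently overwrite the [2, 1] case)
INPUTS = {
    42: [5, 6, 2, 7, 4],
    6: [1, 2, 3],
    140: [7, 5, 14, 2, 8, 8, 10, 1, 2, 3],
    2: [2, 1],
    3: [1, 3],
    10**5 * 9 * 10**4: [10**5, 9 * 10**4],
}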
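`MAX_CASE` holds 1 through `MAX_VALUE`, so its two largest values are `MAX_VALUE` and `MAX_VALUE - 1`. A quick check of the expected output follows. It also shows why the element type matters: the product is larger than a signed 32-bit integer can hold. If the inputs come in as a NumPy array, they need a 64-bit dtype, because NumPy's default integer width has historically been platform-dependent. Python's own `int` has no such limit. The constant names here are mine.

```python
MAX_VALUE = 2 * 10**5
INT32_MAX = 2**31 - 1  # largest signed 32-bit integer (2147483647)
INT64_MAX = 2**63 - 1  # largest signed 64-bit integer

# the two largest values in MAX_CASE sit at different positions
largest_product = MAX_VALUE * (MAX_VALUE - 1)
print(largest_product)  # 39999800000
print(largest_product > INT32_MAX)  # True: would overflow 32 bits
print(largest_product <= INT64_MAX)  # True: fits in 64 bits
```

The constraints also allow the value \(2 \times 10^5\) to appear twice. That gives \(4 \times 10^{10}\), which is likewise too big for 32 bits.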
Helpers
This helper checks an implementation against the known inputs and, optionally, the maximum-size case.
def check_inputs(implementation, inputs=INPUTS, use_max=True):
    """Checks the implementation against the expected outputs

    Args:
     implementation: callable to check
     inputs (dict): maps each expected output to its input values
     use_max (bool): if True use the max-range value (too slow for brute force)

    Raises:
     AssertionError: one of the outputs wasn't expected
    """
    for expected, input_values in inputs.items():
        start = datetime.now()
        actual = implementation(input_values)
        assert actual == expected, "Inputs: {} Expected: {} Actual: {}".format(
            input_values,
            expected,
            actual)
        print("Elapsed Time: {}".format(datetime.now() - start))
    if use_max:
        print("Max Case")
        start = datetime.now()
        actual = implementation(MAX_CASE["inputs"])
        expected = MAX_CASE["output"]
        assert actual == expected, "Inputs: {} Expected: {} Actual: {}".format(
            MAX_CASE["inputs"],
            expected,
            actual)
        print("Elapsed Time: {}".format(datetime.now() - start))
    print("All passed")
    return
Brute Force Implementation
This is given as part of the problem. It multiplies every pair of values and keeps the largest product.
def max_pairwise_product_brute(numbers):
    """Calculates the largest pairwise-product for a list

    Args:
     numbers (list): integers to check

    Returns:
     int: largest product created from numbers
    """
    n = len(numbers)
    max_product = 0
    for first in range(n):
        for second in range(first + 1, n):
            max_product = max(max_product,
                              numbers[first] * numbers[second])
    return max_product

check_inputs(max_pairwise_product_brute, use_max=False)
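Because of the nested for-loops, the brute-force version multiplies every pair of positions, which is \(n(n-1)/2\) products, so its run time is \(O(n^2)\). Since \(n\) can be anywhere from \(2\) to \(2 \times 10^5\), the worst case is about \(2 \times 10^{10}\) multiplications (on the order of \(n^2 = 4 \times 10^{10}\)). That is far too many for a five-second limit.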
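To put a number on that, here is the pair count for the largest allowed input. The time estimate below assumes a rate of \(10^8\) simple operations per second. That figure is a rough assumption for illustration, not a measurement, and it is optimistic for an interpreted Python loop.

```python
n = 2 * 10**5  # largest input size the constraints allow

# each unordered pair of positions is multiplied exactly once
pairs = n * (n - 1) // 2
print(pairs)  # 19999900000

assumed_rate = 10**8  # assumed operations per second, not measured
print(round(pairs / assumed_rate))  # 200 (seconds)
```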
Running this through the grader gives this output.
Failed case #4/17: time limit exceeded (Time used: 9.98/5.00, memory used: 20905984/536870912.)
Search
Instead of using nested for-loops, we can make two passes over the numbers. The first pass finds the largest value and the second finds the largest value at any other position. That is about \(2n\) steps, so the run time is \(O(n)\).
def max_pairwise_product_take_two(numbers):
    """Finds the maximum pairwise product in the numbers

    Args:
     numbers (list): non-negative integers

    Returns:
     int: largest possible product from the numbers
    """
    n = len(numbers)
    assert n >= 2
    # first pass: start from index 0 so it is compared like every other value
    first_index = 0
    first_value = numbers[0]
    for index in range(1, n):
        if numbers[index] > first_value:
            first_value = numbers[index]
            first_index = index
    # second pass: the largest value at any other position
    second_value = 0
    start = 1 if first_index == 0 else 0
    for index in range(start, n):
        if index != first_index and numbers[index] > second_value:
            second_value = numbers[index]
    return first_value * second_value

check_inputs(max_pairwise_product_take_two)
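For comparison, the same "find the largest, then the largest of the rest" idea can be written with Python built-ins. This is a sketch, and the function name is mine. It copies the input into a list and makes a few linear scans, so it is still \(O(n)\) but uses more memory than the loop above.

```python
def max_pairwise_product_builtins(numbers):
    """Maximum pairwise product: largest value times largest of the rest."""
    values = list(numbers)
    assert len(values) >= 2
    # max() scans once for the largest value, index() scans again to locate it
    first_index = values.index(max(values))
    # pop() removes that one position, so a duplicate maximum stays behind
    first_value = values.pop(first_index)
    second_value = max(values)
    return first_value * second_value


print(max_pairwise_product_builtins([5, 6, 2, 7, 4]))  # 42
print(max_pairwise_product_builtins([3, 3]))  # 9
```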
This one passes the grader and does surprisingly well; I had expected it to need more optimizing.
Good job! (Max time used: 0.15/5.00, max memory used: 26734592/536870912.)
Another Improvement
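Rather than making a second pass, we can track the two largest values in a single pass. Whenever a new maximum shows up, the old maximum becomes the runner-up. Otherwise, the number might still beat the current runner-up, so it gets compared against that as well.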
def max_pairwise_product_take_three(numbers):
    """Finds the maximum pairwise product in the numbers

    Args:
     numbers (list): non-negative integers

    Returns:
     int: largest possible product from the numbers
    """
    max_value = 0
    previous_value = 0
    n = len(numbers)
    assert n >= 2
    for number in numbers:
        if number > max_value:
            # the old maximum becomes the runner-up
            previous_value, max_value = max_value, number
        elif number > previous_value:
            previous_value = number
    return max_value * previous_value

check_inputs(max_pairwise_product_take_three)
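To see why the single pass works, here is the same update rule instrumented to record the (largest, runner-up) pair after each number, run on the first example. The tracing helper `trace_top_two` is purely illustrative.

```python
def trace_top_two(numbers):
    """Returns the (max_value, previous_value) state after each number."""
    max_value = previous_value = 0
    states = []
    for number in numbers:
        if number > max_value:
            # new maximum: the old maximum is demoted to runner-up
            previous_value, max_value = max_value, number
        elif number > previous_value:
            # not a new maximum, but it beats the current runner-up
            previous_value = number
        states.append((max_value, previous_value))
    return states


for state in trace_top_two([5, 6, 2, 7, 4]):
    print(state)
# (5, 0)
# (6, 5)
# (6, 5)
# (7, 6)
# (7, 6)
```

The final state (7, 6) gives 42. For an input like `[7, 6]`, the `elif` branch is what records the 6, since the 6 never becomes a new maximum. For `[8, 8]`, the second 8 is not strictly greater than the maximum, so it too lands in the runner-up slot.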
Stress Test
Even though we're already passing, part of the assignment was to write a stress test that exercises the algorithm with random inputs once it passes.
def stress_test(implementation, tag, maximum_size=MAX_INPUTS,
                maximum_value=MAX_VALUE, iterations=10):
    """Repeatedly creates random inputs to test the implementation

    This compares the output of the implementation against the
    numba-compiled brute-force version (max_pairwise_product_brute_jit).

    Args:
     implementation: callable to test
     tag (str): something to identify the implementation
     maximum_size (int): the maximum number of numbers for an input
     maximum_value (int): the maximum value for any input
     iterations (int): the number of times to test (if None runs infinitely)
    """
    true_count = 0
    iteration = 0
    # with iterations=None the counter never advances, so the loop never ends
    increment = 1 if iterations is not None else 0
    iterations = 1 if iterations is None else iterations
    max_time = timedelta(0)
    while iteration < iterations:
        start = datetime.now()
        true_count += 1
        iteration += increment
        print("***** ({}) Trial: {} *****".format(tag, true_count))
        n = random.randrange(2, maximum_size + 1)
        print("Input Size: {}".format(n))
        # 64-bit values so the products can't overflow
        inputs = numpy.random.randint(maximum_value + 1, size=n, dtype=numpy.int64)
        print("Running Brute Force")
        brute_start = datetime.now()
        output_brute = max_pairwise_product_brute_jit(inputs)
        print("Brute Force Time: {}".format(datetime.now() - brute_start))
        print("Running {} implementation".format(tag))
        implementation_start = datetime.now()
        output_implementation = implementation(inputs)
        implementation_end = datetime.now()
        implementation_elapsed = implementation_end - implementation_start
        if implementation_elapsed > MAX_TIME:
            print("Error Time Exceeded: {}".format(implementation_elapsed))
            break
        print("Implementation Time: {}".format(implementation_elapsed))
        if implementation_elapsed > max_time:
            max_time = implementation_elapsed
        if output_brute != output_implementation:
            print("error: Expected {}, Actual {}".format(output_brute,
                                                        output_implementation))
            print("Inputs: {}".format(inputs))
            break
        print("Elapsed time: {}".format(datetime.now() - start))
    print("Max {} time: {}".format(tag, max_time))
    return
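Because the brute force is quadratic, the full-size trials above have to lean on numba. A lighter, standard-library-only alternative keeps the inputs small enough for a plain brute force and runs many more trials. The edge cases this problem is prone to (ties, the maximum at index 0, zeros) already show up in short lists. Here is a minimal sketch; all the names are hypothetical, and it uses a fixed seed so any failure is reproducible.

```python
import random


def brute_force(numbers):
    """Reference answer: checks every pair of positions."""
    return max(numbers[i] * numbers[j]
               for i in range(len(numbers))
               for j in range(i + 1, len(numbers)))


def one_pass(numbers):
    """Same update rule as max_pairwise_product_take_three."""
    max_value = previous_value = 0
    for number in numbers:
        if number > max_value:
            previous_value, max_value = max_value, number
        elif number > previous_value:
            previous_value = number
    return max_value * previous_value


def small_stress_test(implementation, trials=1000, max_size=10,
                      max_value=10, seed=0):
    """Returns the first input where implementation disagrees, else None."""
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(2, max_size)  # randint is inclusive on both ends
        numbers = [rng.randint(0, max_value) for _ in range(n)]
        if implementation(numbers) != brute_force(numbers):
            return numbers
    return None


print(small_stress_test(one_pass))  # None
```

Returning the failing list instead of just printing it makes it easy to paste straight into `INPUTS` as a new regression case.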
Boosted Brute Force
The pure-Python brute force is far too slow for inputs this large, so I'm going to use numba to (hopefully) speed it up enough to make the stress test runnable.
@jit
def max_pairwise_product_brute_jit(numbers):
    """Calculates the largest pairwise-product for a list

    Args:
     numbers (list): integers to check

    Returns:
     int: largest product created from numbers
    """
    n = len(numbers)
    max_product = 0
    for first in range(n):
        for second in range(first + 1, n):
            max_product = max(max_product,
                              numbers[first] * numbers[second])
    return max_product

print("One Pass Method")
stress_test(max_pairwise_product_take_three, tag="One-Pass", iterations=10)
Using Sort
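Since we only need the top two values, another option is to sort the values in descending order and multiply the first two. Sorting takes \(O(n \log n)\) time, so this is asymptotically slower than the single pass, but it is very short to write.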
def max_pairwise_product_sort(numbers):
    """Calculates the largest pairwise-product for a list

    Args:
     numbers (list): integers to check

    Returns:
     int: largest product created from numbers
    """
    assert len(numbers) > 1
    numbers = sorted(numbers, reverse=True)
    return numbers[0] * numbers[1]

print("\n\nSort method")
stress_test(max_pairwise_product_sort, tag="Sort", iterations=100)
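Sorting orders all \(n\) values even though only the top two matter. The standard library's `heapq.nlargest(k, iterable)` returns just the \(k\) largest items in descending order, and it only needs to keep track of \(k\) items while it scans. For \(k = 2\) that is effectively linear. The sketch below is a swapped-in alternative, not something from the assignment, and the function name is mine.

```python
import heapq


def max_pairwise_product_nlargest(numbers):
    """Multiplies the two largest values found by heapq.nlargest."""
    assert len(numbers) > 1
    # nlargest returns a list in descending order; duplicates are kept
    first, second = heapq.nlargest(2, numbers)
    return first * second


print(max_pairwise_product_nlargest([7, 5, 14, 2, 8, 8, 10, 1, 2, 3]))  # 140
```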