Jupyter Snippet P4M 11

All of these python notebooks are available at [https://gitlab.erc.monash.edu.au/andrease/Python4Maths.git]

Speed with Python


As an interpreted language, Python is natively fairly slow. However, there are a number of ways to speed Python up by using a compiler or a just-in-time compiler; examples include PyPy and Cython. Here we will use a less intrusive approach that comes with the Anaconda distribution: the numba library, which allows just-in-time compilation of Python code with some limitations.


While Python is not really designed for fast computation, the numba library allows speeds that come close to C/Fortran for selected functions that perform computationally intensive tasks. This is achieved by importing the numba library and marking up computationally intensive functions with the @jit decorator. This decorator requests the function to be compiled (without any other intervention by the user). For this to be effective, the functions typically have to be restricted to a subset of Python data types (in particular lists, NumPy arrays including 2-D arrays used as matrices, tuples and numbers). While some other data types are supported and the level of support is increasing, complex data types provided by external Python libraries are unlikely ever to be supported.
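A minimal sketch of the decorator in use, assuming numba is installed. `numba.njit` is shorthand for `numba.jit(nopython=True)`; without an explicit signature the function is compiled lazily on its first call, for the argument types seen in that call. The `try`/`except` fallback is our own addition (not part of numba) so the sketch still runs, uncompiled, if numba is missing, and `sum_of_squares` is a hypothetical example function.

```python
import numpy as np

try:
    from numba import njit  # njit is shorthand for jit(nopython=True)
except ImportError:
    # Assumption: numba is unavailable, so use a no-op decorator and run as plain Python
    def njit(func):
        return func

@njit
def sum_of_squares(x):
    """Sum of x[i]**2 over a 1-D float64 array, written as an explicit loop."""
    total = 0.0
    for i in range(x.shape[0]):
        total += x[i] * x[i]
    return total

x = np.arange(1000, dtype=np.float64)
# The first call compiles a float64[:] version; later calls with the same types reuse it.
print(sum_of_squares(x))  # 332833500.0
```

Because compilation happens on the first call, that call is noticeably slower than the following ones; exclude it when timing a lazily compiled function.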


Here is an example of using numba to speed up matrix multiplication. The function is defined twice with the same code, once run by the normal interpreter and once using numba's just-in-time compilation. For comparison we also include matrix multiplication with the numpy library, which performs the core computation in compiled code. In the run shown below the jit version is several hundred times faster than standard Python and comparable to the numpy code (here even slightly faster; note that each timing also includes generating the random input matrices). The numpy implementation uses the BLAS library (when available), which is very heavily optimised compared to the simple-minded matrix multiplication below, so it can be expected to pull ahead for larger matrices.

import numba
import numpy as np
from timeit import timeit

def pyMult(A,B): # very simple/crude matrix multiplication
    m,n = A.shape
    p,q = B.shape
    if n != p: raise ValueError("invalid dimensions: columns of A != rows of B")
    C = np.zeros((m,q))                     # result matrix
    for i in range(m):
        for j in range(q):
            c = 0.0                         # accumulate the (i,j) dot product
            for k in range(n): c += A[i,k]*B[k,j]
            C[i,j] = c
    return C

## Here we mark a function as being "just-in-time compiled" (jit)
## In addition we tell the compiler that this is a function that
## - has two arguments that are 64bit floating point matrices
## - returns a 64 bit floating point matrix
## - no interpreted python should be used (nopython=True)
## The arguments in brackets are optional (@numba.jit is sufficient)
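@numba.jit("float64[:,:](float64[:,:], float64[:,:])", nopython=True)
def jitMult(A,B):
    m,n = A.shape
    p,q = B.shape
    if n != p: raise ValueError("invalid dimensions: columns of A != rows of B")
    C = np.zeros((m,q))                     # result matrix
    for i in range(m):
        for j in range(q):
            c = 0.0                         # accumulate the (i,j) dot product
            for k in range(n): c += A[i,k]*B[k,j]
            C[i,j] = c
    return C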
mat=lambda m,n: np.random.rand(m,n)
print("Normal python %.4f sec" % timeit(lambda: pyMult(mat(100,200),mat(200,50)),number=9))
print("JIT    python %.4f sec" % timeit(lambda: jitMult(mat(100,200),mat(200,50)),number=9))
print("Numpy C code  %.4f sec" % timeit(lambda: mat(100,200) @ mat(200,50), number=9))
Normal python 5.0939 sec
JIT    python 0.0091 sec
Numpy C code  0.0153 sec
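Timings only mean something if the loop version computes the right product. A quick sanity check is to compare it against NumPy's `@` operator with `np.allclose`. The sketch below uses a hypothetical stand-alone `loop_mult` that repeats the triple loop of `pyMult` so that it runs on its own; the same check works for `jitMult`.

```python
import numpy as np

def loop_mult(A, B):
    """Triple-loop matrix product (same algorithm as pyMult)."""
    m, n = A.shape
    p, q = B.shape
    if n != p:
        raise ValueError("invalid dimensions")
    C = np.zeros((m, q))
    for i in range(m):
        for j in range(q):
            c = 0.0
            for k in range(n):
                c += A[i, k] * B[k, j]
            C[i, j] = c
    return C

rng = np.random.default_rng(0)
A = rng.random((4, 3))
B = rng.random((3, 2))
# Floating point sums may differ in the last bits, so compare with a tolerance
print(np.allclose(loop_mult(A, B), A @ B))  # True
```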

Speed Exercise

Consider the following problem: given a list of integers (possibly with repeats), create the set of all unique triplets of numbers from this list so that the sum of numbers is zero. This can be written very simply in python:

# simplistic definition: any tuple of numbers at 3 distinct indices of N, in ascending order, that sums to 0
s3=lambda N: {(n1,n2,n3) for i1,n1 in enumerate(N) for i2,n2 in enumerate(N) for i3,n3 in enumerate(N)
              if len({i1,i2,i3}) ==3 and n1<=n2<=n3 and n1+n2+n3==0}
from itertools import combinations  # use itertools library to create combinations
s3basic=lambda N:{ tuple(sorted([i,j,k])) for i,j,k in combinations(N,3) if i+j+k==0}
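The two definitions should produce the same set, and a small random test is an easy way to confirm it. The sketch also shows why `s3basic` is cheaper: `s3` visits all n³ ordered index triples, while `combinations` visits each set of three positions only once. The test list `N` is arbitrary.

```python
import random
from itertools import combinations

s3 = lambda N: {(n1, n2, n3) for i1, n1 in enumerate(N) for i2, n2 in enumerate(N)
                for i3, n3 in enumerate(N)
                if len({i1, i2, i3}) == 3 and n1 <= n2 <= n3 and n1 + n2 + n3 == 0}
s3basic = lambda N: {tuple(sorted([i, j, k])) for i, j, k in combinations(N, 3) if i + j + k == 0}

random.seed(1)
N = [random.randint(-10, 10) for _ in range(30)]
print(s3(N) == s3basic(N))                     # True: same set of sorted triplets
print(len(N)**3, len(list(combinations(N, 3))))  # 27000 4060
```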

How fast can we make this computation? Here are two functions in Python with the second being significantly faster than the first.

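def three_sum(num):
    "A first implementation: assumes num is a list/array of integers"
    if len(num)<3: return []
    num = sorted(num)       # sorted copy (the caller's list is left unchanged)
    result = []
    for i in range(len(num)-2):
        if i!=0 and num[i]==num[i-1]: continue    # skip repeated first elements
        left, right = i+1, len(num)-1              # two pointers over the rest
        while left<right:
            if num[left]+num[right]==-num[i]:
                result.append([num[i],num[left],num[right]])
                left, right = left+1, right-1
                while left<right and num[left]==num[left-1]: left=left+1
                while left<right and num[right]==num[right+1]: right=right-1
            elif num[left]+num[right]<-num[i]:
                left=left+1
            else:
                right=right-1
    return result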

from collections import Counter # count repeated entries
def py_3sum(num):
    "Given a list/array of integers, return unique sets of triplets summing to zero"
    if len(num)<3: return []
    count = Counter(num)
    num = sorted(count.keys())  # unique numbers in order
    positive= {i for i in num if i>0} # last number must be >0
    result = []
    end = len(num)-1            # index of the largest candidate for the middle number
    for i,first in enumerate(num):
        if first >= 0: break    # first number must be <0 ((0,0,0) is handled below)
        last = -first//2        # middle number is at most -first/2
        while num[end] > last: end -= 1
        if num[end] == last and first%2==0:
            if count[last]>1: result.append( (first,last,last))   # e.g. (-4,2,2)
            end -= 1
        for second in num[i+(count[first]==1):end+1]:   # second may equal first if repeated
            if -(first+second) in positive:
                result.append( (first,second,-(first+second)) )
    if count[0] >= 3: result.append( (0,0,0) )
    return result
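`three_sum` fixes the first element and then finds the other two with a two-pointer scan over the sorted list. That inner step can be sketched on its own as a hypothetical helper `pairs_with_sum`: on a sorted list, move `left` up when the sum is too small and `right` down when it is too large, so each step discards one candidate and the scan is linear in the list length.

```python
def pairs_with_sum(sorted_nums, target):
    """Unique pairs (a, b), a <= b, from a sorted list with a + b == target."""
    left, right = 0, len(sorted_nums) - 1
    pairs = []
    while left < right:
        s = sorted_nums[left] + sorted_nums[right]
        if s == target:
            pairs.append((sorted_nums[left], sorted_nums[right]))
            left += 1
            right -= 1
            # skip repeated values so each pair is reported only once
            while left < right and sorted_nums[left] == sorted_nums[left - 1]:
                left += 1
            while left < right and sorted_nums[right] == sorted_nums[right + 1]:
                right -= 1
        elif s < target:
            left += 1   # sum too small: try a larger left value
        else:
            right -= 1  # sum too large: try a smaller right value
    return pairs

print(pairs_with_sum([-3, 2, 2, 4, 8, 10], 10))  # [(2, 8)]
```

Calling this with `target = -first` for each distinct `first` is exactly what the loop in `three_sum` does, which is why its cost is quadratic rather than cubic.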

Let’s check correctness with a small list

nums = [-25,-10,-10,-7,-4,-3,2,2,4,8,10]
print("Original     ",sorted(s3(nums))) # for easy comparison
print("First python ",three_sum(nums))
print("Second python",py_3sum(nums))
Original      [(-10, 2, 8), (-7, -3, 10), (-4, 2, 2)]
First python  [[-10, 2, 8], [-7, -3, 10], [-4, 2, 2]]
Second python [(-10, 2, 8), (-7, -3, 10), (-4, 2, 2)]

Now let's test this for speed. We will ignore the s3 function as it is really slow! Note that the test still takes some time to run, as the first Python method is also not very fast (around 4 to 5 seconds per call with a list of 5000 numbers).

from timeit import timeit
import random
n = 5000        # length of the test list
repeat = 3      # number of timed calls to average over (adjust as needed)
nums = [random.randint(-10000,10000) for i in range(n)]
total=timeit(lambda: three_sum(nums), number=repeat)
print("First python  = %9.4f ms / call"%(total*1000/repeat))
total=timeit(lambda: py_3sum(nums), number=repeat)
print("Second python = %9.4f ms / call"%(total*1000/repeat))
First python  = 4119.3223 ms / call
Second python =  403.1344 ms / call
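These figures are averages over a few calls, and timings fluctuate from run to run. One common, slightly more robust approach (a sketch, not what the notebook itself does) is `timeit.repeat`, which returns one total per round, and to report the fastest round. The workload below is a placeholder; substitute for example `lambda: py_3sum(sample)`.

```python
import random
from timeit import repeat as timeit_repeat  # aliased so it does not clash with the `repeat` variable

random.seed(0)
sample = [random.randint(-10000, 10000) for _ in range(5000)]
work = lambda: sorted(set(sample))  # placeholder workload

rounds, calls = 5, 10
totals = timeit_repeat(work, repeat=rounds, number=calls)  # `rounds` totals, in seconds
best_ms = min(totals) * 1000 / calls
print("best of %d rounds = %9.4f ms / call" % (rounds, best_ms))
```

The minimum is usually preferred over the mean because slower rounds mostly reflect interference from other processes rather than the code being measured.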

Challenge: write a function using the numba jit functionality that runs at least twice as fast as the pure python code. Note that you may need to use a slightly different algorithm and perhaps convert the list of integers to a numpy array of int32 numbers as input to the jit function. So the signature of the jit compiled function might be something like:

@numba.jit(nopython=True)
def jit_3sum(num):
    pass # write your code here
# test it (uses nums and repeat from the timing cell above)
total=timeit(lambda: jit_3sum(np.array(nums,np.int32)), number=repeat)
print("jit python    = %9.4f ms / call"%(total*1000/repeat))

It should be possible to get this to run at least twice as fast as the second python version, though getting a whole order of magnitude speed-up is difficult.
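As a possible starting point (one option among many, not a solution), the `Counter`/`sorted` step of `py_3sum` has a direct NumPy counterpart: `np.unique` with `return_counts=True` returns the sorted unique values together with their multiplicities, both as arrays that can be passed to a jit function. A set-membership test like `in positive` can be replaced by a binary search with `np.searchsorted`; `contains` is a hypothetical helper, and whether each call is usable inside a numba-compiled function should be checked against numba's list of supported NumPy features.

```python
import numpy as np

small = [-25, -10, -10, -7, -4, -3, 2, 2, 4, 8, 10]
values, counts = np.unique(np.array(small, dtype=np.int32), return_counts=True)
print(values)  # [-25 -10  -7  -4  -3   2   4   8  10]
print(counts)  # [1 2 1 1 1 2 1 1 1]

def contains(sorted_vals, x):
    """Membership test on a sorted array by binary search."""
    i = np.searchsorted(sorted_vals, x)  # first index with sorted_vals[i] >= x
    return i < len(sorted_vals) and sorted_vals[i] == x

print(contains(values, 8), contains(values, 5))  # True False
```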