Speeding up non-vectorizable code with Cython

Most people know that when working with numeric arrays in Python, it is almost always faster to use vectorized operations instead of loops. But what happens when there's no obvious way to vectorize a slow function? There are many approaches for speeding up code, and plenty of other great posts on the subject. This is just a simple example with Cython.

The set up

Let's take an arbitrary example of some function with logic that makes it annoying to vectorize:

Excel uses column names like A, B, C ... AA, AB, AC up to infinite. Write a function to convert the column name to its column index [...]

This function is a little contrived, but the same process applies to almost any function that is either difficult or impossible to vectorize (or is not actually faster when vectorized).

Let's create an array of fake data:

In [1]:
import string
import random

def get_random_col_name(low=1, high=6):
    n = random.randint(low, high)
    return ''.join(random.choice(string.ascii_uppercase) for _ in range(n))

cols = [get_random_col_name() for _ in range(1000000)]

print('first ten cols (of {}):\n{}'.format(len(cols), '\n'.join(cols[:10])))
first ten cols (of 1000000):
HI
CURGQ
EOBCS
Y
GLRT
GTYJ
Y
JJG
K
JAZJ

Here's a plain Python implementation of the index calculation:

In [2]:
def column_to_index(col):
    ord_a = 65  # equivalent to ord('A')
    col_index = 0
    for place, letter in enumerate(reversed(col)):
        col_index += (26 ** place) * (ord(letter) - ord_a + 1)
    return col_index - 1
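The loop treats a name as a bijective base-26 numeral. Each letter is a digit from 1 to 26 rather than 0 to 25, and the final `- 1` shifts the result to a zero-based index. The notation below is ours, not the post's. It writes a name as $c_{n-1} \dots c_1 c_0$, with $c_0$ the rightmost letter:

```latex
\text{index}(c_{n-1} \dots c_1 c_0) = \left( \sum_{k=0}^{n-1} 26^{k}\, v(c_k) \right) - 1,
\qquad v(\mathrm{A}) = 1,\; v(\mathrm{B}) = 2,\; \dots,\; v(\mathrm{Z}) = 26
```

For example, $\text{index}(\mathrm{AB}) = 26^1 \cdot 1 + 26^0 \cdot 2 - 1 = 27$, which agrees with the assert below.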

Quick tests for sanity check:

In [3]:
assert column_to_index('A') == 0
assert column_to_index('Z') == 25
assert column_to_index('AA') == 26
assert column_to_index('AB') == 27
assert column_to_index('ZAZ') == (26 ** 2) * 26 + (26 ** 1) * 1 + (26 ** 0) * 26 - 1
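Hand-picked asserts can miss cases. One way to widen coverage is a round-trip check against an inverse function. The `index_to_column` helper here is hypothetical and written only for testing; it is not part of the original post. The reference implementation is repeated so the snippet runs on its own.

```python
def column_to_index(col):
    # Reference implementation from above (zero-based index).
    ord_a = 65
    col_index = 0
    for place, letter in enumerate(reversed(col)):
        col_index += (26 ** place) * (ord(letter) - ord_a + 1)
    return col_index - 1


def index_to_column(index):
    """Hypothetical inverse: zero-based column index -> column name."""
    if index < 0:
        raise ValueError('index must be non-negative')
    letters = []
    n = index + 1  # back to the one-based bijective base-26 value
    while n > 0:
        # digits run 1..26, so subtract one before splitting off the last digit
        n, remainder = divmod(n - 1, 26)
        letters.append(chr(ord('A') + remainder))
    return ''.join(reversed(letters))


# Round trip over the first 20,000 indices, which covers names with
# repeated letters such as 'AA' and 'ZZ'.
for i in range(20000):
    assert column_to_index(index_to_column(i)) == i
```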

Let's see how fast this implementation is when run on all 1,000,000 random column names:

In [4]:
%time idxs = [column_to_index(col) for col in cols]
CPU times: user 3.47 s, sys: 52.1 ms, total: 3.52 s
Wall time: 3.51 s

Speeding things up with Cython

Cython is a superset of Python that allows for optimization by declaring types and compiling Python code. The documentation is very good, and there is a great intro here.

A few general principles:

  • Don't optimize until you know what the slow part is and that it will actually be problematic. Self-explanatory, but it's a thing. You may want to do some line profiling first to see if the slow part is what you actually thought it was, and whether there are easy wins in pure Python.

  • Don't optimize if doing the optimization will take longer than the time it will save. Particularly true for one-off code, less true when writing something that will be run often. This is a judgment call.

  • Start with working, vanilla Python and make small changes from there. Cython is a superset of the Python syntax, so you can actually start with the original and it will compile. Think twice before doing a from-scratch rewrite of the function. You probably want tests or, at the very least, sanity checks to make sure the functionality is the same as your reference implementation. There is a big risk of introducing subtle bugs in a function that is supposed to be identical.

  • The cardinal rule of optimizing: stop when it's fast enough. Speeding things up can be entertaining but "optimization golf" doesn't really help anybody in the long term, especially as the code becomes harder and harder to read.
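It can pay to look for an easy pure-Python win before reaching for Cython. The reference implementation recomputes 26 ** place for every letter. Horner's rule instead reads the name left to right and multiplies the running total by 26 at each step, so it needs neither exponentiation nor reversed(). The sketch below is an illustration added here and is not part of the original benchmark. Measure any speedup on your own data with %time rather than assuming one.

```python
def column_to_index_horner(col):
    """Zero-based column index via Horner's rule (illustrative variant)."""
    ord_a = ord('A')
    col_index = 0
    for letter in col:  # left to right, so no reversed() is needed
        # shift the running value one base-26 place, then add this letter's 1..26 digit
        col_index = col_index * 26 + (ord(letter) - ord_a + 1)
    return col_index - 1


# Same sanity checks as for the reference implementation
assert column_to_index_horner('A') == 0
assert column_to_index_horner('Z') == 25
assert column_to_index_horner('AA') == 26
assert column_to_index_horner('AB') == 27
assert column_to_index_horner('ZAZ') == (26 ** 2) * 26 + (26 ** 1) * 1 + (26 ** 0) * 26 - 1
```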

And a few key tricks for getting the fastest (correct) result possible:

  • Declare appropriate types wherever practical. This is usually the first set of changes, and often the source of major efficiency wins.

  • Use simple control structures and avoid generic Python iteration. Think about how you would write a given loop in C, because a loop written that way will probably compile to faster code than idiomatic Python would. For example, we are trained to avoid this kind of list iteration in Python:

      N = len(arr)
      for i in range(N):
          value = arr[i]
          # do stuff with value
    

    in favor of the more Pythonic:

      for value in arr:
          # do stuff with value
    

    but the for item in iterable idiom goes through Python's generic iterator protocol, which adds overhead so that it can handle any kind of iterable. In Cython, an index-based loop like the first one, with i and N declared as C integers, can compile down to a plain C loop. You may also have already declared a C type for value, since you know what kind of data you are iterating over.

  • If you're working with arrays, consider disabling some of the safety checking. By default, Cython indexing into typed arrays behaves like Python's: it raises an exception if you go out of array bounds and treats negative indices as counting from the end. Of course, these checks add overhead. If you have already taken things like array bounds into account, you can tell the compiler to omit them with decorators on your function such as @cython.boundscheck(False) (skip bounds checks) or @cython.wraparound(False) (skip negative-index handling). A list of compiler directives can be found here.

Outside a notebook, the common workflow is to move the slowest parts of your program into a .pyx file, have Cython translate it into C code (.c), and then compile that into an extension module (a shared object file, .so, on Linux and macOS; a .pyd on Windows). From there, you can import your optimized functions in any .py file and use them as normal.

Conveniently, IPython notebooks can load Cython's IPython extension, which lets you write Cython code directly in a notebook cell with the %%cython magic.

In [5]:
%load_ext Cython

Here is the re-write:

In [6]:
%%cython
cimport cython

@cython.boundscheck(False)  # these don't contribute much in this example
@cython.wraparound(False)   # but are often useful for numeric arrays
def column_to_index_cy(str col):
    cdef int i, n, letter, ord_a  # declare types - easy win
    cdef unsigned long long col_index = 0  # max value alert! (see below)
    
    ord_a = 65
    n = len(col)
    i = 0
    
    while i < n:  # very C-like
        letter = ord(col[n - i - 1])  # access str in reverse order
        col_index += (26 ** i) * (letter - ord_a + 1)
        i += 1
        
    return col_index - 1

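Let's convince ourselves that these are equivalent, at least for every column name of up to four letters (A through ZZZZ, $26 + 26^2 + 26^3 + 26^4 = 475{,}254$ names):

In [7]:
from itertools import product
import string

for n in range(1, 5):
    # product (unlike permutations) includes repeated letters such as 'AA' and 'ZZZZ'
    for letters in product(string.ascii_uppercase, repeat=n):
        col = ''.join(letters)
        assert column_to_index_cy(col) == column_to_index(col)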

Now we can see how much faster the Cython version is:

In [8]:
%time idxs = [column_to_index_cy(col) for col in cols]
CPU times: user 315 ms, sys: 56.3 ms, total: 371 ms
Wall time: 352 ms

This is an enormous speedup: the Cython version takes around 10% as long as the pure Python version (352 ms versus 3.51 s of wall time).

Lower level means worrying about more details

Here, for example, we have to remember that the C types we chose have strict limits and will silently overflow in many cases. Python integers grow as large as needed, but our unsigned long long can only hold values up to $2^{64} - 1$ and wraps around past that point:

In [9]:
%matplotlib inline
from matplotlib import pyplot as plt
import seaborn as sns
from math import log

ullong_max = 2**64-1
print('max value for unsigned long long: {}'.format(ullong_max))

blowup = log(ullong_max, 26)
print('max letters: {}'.format(blowup))

xs = range(20)
ys = [26**x for x in xs]

fig, ax = plt.subplots(figsize=(10, 6))

plt.axhline(ullong_max, c='r', ls='--', label='largest index we can store')
plt.plot(xs, ys, label='growth of column index')
plt.semilogy()

ax.annotate('sadness ensues here', xy=(blowup, ullong_max),
            xycoords='data', xytext=(0.85, 0.4), textcoords='axes fraction',
            arrowprops=dict(facecolor='black', shrink=0.05, width=2),
            horizontalalignment='center', verticalalignment='top')

plt.xlabel('number of letters in column name')
plt.ylabel('column index')
plt.legend(loc='lower right', fontsize=16)
plt.show()
max value for unsigned long long: 18446744073709551615
max letters: 13.6157474274

We can see this issue in action:

In [10]:
[column_to_index('A' * i) - column_to_index_cy('A' * i) for i in range(1, 20)]
Out[10]:
[0L,
 0L,
 0L,
 0L,
 0L,
 0L,
 0L,
 0L,
 0L,
 0L,
 0L,
 0L,
 0L,
 0L,
 55340232221128654848L,
 1733993942928697851904L,
 45342096933178077872128L,
 1179171221423735667949568L,
 30658673117946011881308160L]
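The nonzero entries are not random. Each one is an exact multiple of $2^{64}$; for example, 55340232221128654848 is $3 \times 2^{64}$. That pattern is what you would expect if the Cython result were the true index wrapped modulo $2^{64}$. C guarantees this wraparound for unsigned types, but how the intermediate terms behave depends on how Cython types them. The pure-Python model below is therefore an assumption that happens to match the output above, not a specification of what Cython does. `column_to_index_wrapping` is a hypothetical name.

```python
MASK_64 = (1 << 64) - 1  # largest value an unsigned 64-bit integer can hold


def column_to_index(col):
    """Reference implementation with exact Python integers."""
    col_index = 0
    for place, letter in enumerate(reversed(col)):
        col_index += (26 ** place) * (ord(letter) - ord('A') + 1)
    return col_index - 1


def column_to_index_wrapping(col):
    """Model of unsigned 64-bit arithmetic: keep only the low 64 bits."""
    col_index = 0
    for place, letter in enumerate(reversed(col)):
        term = (26 ** place) * (ord(letter) - ord('A') + 1)
        col_index = (col_index + term) & MASK_64  # wrap modulo 2**64
    return (col_index - 1) & MASK_64


diffs = [column_to_index('A' * i) - column_to_index_wrapping('A' * i)
         for i in range(1, 20)]
# every discrepancy is a whole number of trips around 2**64
assert all(d % 2 ** 64 == 0 for d in diffs)
print([d // 2 ** 64 for d in diffs])  # number of wraps for each name length
```

The model gives zero difference for names of up to 14 letters. It reproduces the nonzero entries of Out[10], for example $3 \times 2^{64}$ for 15 letters and $94 \times 2^{64}$ for 16. This is consistent with silent wraparound rather than arbitrary garbage.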

There are ways to mitigate type problems, some more painful than others. Can you afford for this to be a little slower? Leave the relevant variables as Python objects, and Python's arbitrary-precision integers will handle the large values dynamically. Maybe you want to go wild and use a library like GMP for arbitrary precision. In the real world, maybe your business logic maxes out at 10 characters anyway, making this a total non-issue and a waste of time to deal with. It's all about keeping the end goal in mind.
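One middle-ground sketch assumes long names are rare. It uses the fast 64-bit function only where it is safe and falls back to exact Python integers otherwise. Every name of up to 13 letters has an index of at most $(26^{14} - 26)/25 - 1 \approx 2.6 \times 10^{18}$, which is below $2^{64} - 1$, so 13 serves as the cutoff. The dispatcher and the stand-in fast function are hypothetical. In the notebook, you would pass column_to_index_cy as fast_impl.

```python
MAX_SAFE_LETTERS = 13  # every 13-letter index is below 2**64 - 1; some 14-letter ones are not


def column_to_index(col):
    """Reference implementation: exact Python integers, never overflows."""
    col_index = 0
    for place, letter in enumerate(reversed(col)):
        col_index += (26 ** place) * (ord(letter) - ord('A') + 1)
    return col_index - 1


def column_to_index_checked(col, fast_impl, slow_impl=column_to_index):
    """Hypothetical dispatcher: use the 64-bit fast path only when it is safe."""
    if len(col) <= MAX_SAFE_LETTERS:
        return fast_impl(col)
    return slow_impl(col)  # long names fall back to arbitrary precision


def fake_fast(col):
    """Stand-in for column_to_index_cy: wraps modulo 2**64 like an unsigned long long."""
    return column_to_index(col) % 2 ** 64


print(column_to_index_checked('Z' * 13, fake_fast))  # fast path, still exact
print(column_to_index_checked('A' * 19, fake_fast))  # falls back to the slow path
```

The length check costs very little compared with the conversion itself. A more precise bound (comparing the name against the largest name that fits) is possible, but a length cutoff is easier to reason about.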

Holding decision points with your fingers for easy backtracking is allowed in a just world

Use when appropriate, not when not

Cython is not the only approach: plenty of good articles out there explain the tradeoffs between Cython and Numba (a just-in-time compiler for numerical Python code). It's just another tool in the toolbox, but it often turns out to be extremely handy and not too much extra work.

Any comments or suggestions? Let me know.