Python: async LRU caching

28 December 2019

Caching synchronous Python functions is easy – just use functools.lru_cache:

from functools import lru_cache

@lru_cache(maxsize=128)  # remember results for the 128 most recently used argument tuples
def slow_function(*args):
    ...  # do something expensive and return result

The lru_cache decorator caches the results of slow_function so you can avoid unnecessary function calls. This is great for expensive or heavily recursive functions. When the cache reaches its maximum size, the least recently used entry is discarded (this is what is meant by “LRU cache”). This size limit can be disabled by passing maxsize=None, at which point the cache grows without bound and lru_cache behaves as a plain memoizer.
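To make the eviction rule concrete, here's a small demo with a deliberately tiny cache. The names `square` and `calls` are just for illustration; `cache_info()` is the statistics method lru_cache attaches to every function it decorates.

```python
from functools import lru_cache

calls = []  # records which arguments actually reached the function body

@lru_cache(maxsize=2)
def square(n):
    calls.append(n)
    return n * n

square(1)  # miss: computed
square(2)  # miss: computed
square(1)  # hit: 1 becomes the most recently used entry
square(3)  # miss: cache is full, so 2 (least recently used) is evicted
square(2)  # miss again: 2 was evicted, so it is recomputed

print(calls)                # [1, 2, 3, 2]
print(square.cache_info())  # CacheInfo(hits=1, misses=4, maxsize=2, currsize=2)
```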

This is really useful. It’s a great example of how Python comes “batteries included”. So I was surprised when I recently found myself wanting to set up an LRU cache for an async function and realized there is no obvious way to do this.
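Here's the obvious thing that doesn't work. Putting lru_cache directly on an async def caches the coroutine object each call returns, not the coroutine's result, and a coroutine can only be awaited once. A minimal sketch, where `fetch` is a made-up stand-in for real async work:

```python
import asyncio
from functools import lru_cache

@lru_cache(maxsize=128)
async def fetch(x):
    await asyncio.sleep(0)  # stand-in for real async work
    return x * 2

async def main():
    first = await fetch(21)  # miss: a fresh coroutine is created and awaited
    try:
        await fetch(21)      # hit: lru_cache returns that same, already-awaited coroutine
    except RuntimeError as exc:
        return first, str(exc)
    return first, None

result = asyncio.run(main())
print(result)  # (42, 'cannot reuse already awaited coroutine')
```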

The stdlib docs don’t mention this case. I didn’t see anything about it on Stack Overflow. I did find a third-party library: it’s called async_lru, and it offers a decorator for async functions that behaves similarly to lru_cache.

That library’s author has basically just gone and reimplemented the stdlib decorator from scratch and thrown in a bunch of async defs. The implementation is a few hundred lines long and seems – I don’t know – fine, I guess, but I like to minimize third-party dependencies in my projects, especially for things which seem this simple. It seems like we should be able to do this without reimplementing logic that already exists in stdlib.

And so, after a bit of hacking I found a way to adapt stdlib’s lru_cache to the async case. I’ve included both specific and general examples below.

The gist is this: the logic we want to offload to lru_cache involves maintaining a cache. When we query the cache, we have two cases: cache hit or cache miss. In the case of a cache hit, we can return the cached result; in the case of a cache miss, we have to do async work. We thus need our async function to be able to distinguish between hits and misses. Unfortunately, lru_cache gives us no way to introspect its internal state: cache_info() reports aggregate hit and miss counts, but nothing tells us whether a particular call was a hit.

In fact, things get even trickier, because introducing async workers adds a third case: not only could the result of our computation be absent or present (cache hit or cache miss), it could also be in progress. We want to be able to detect this case as well and avoid restarting any work we have already begun.
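To see why the in-progress case matters, here's a naive two-state cache built on a plain dict. Everything in it (`cache`, `expensive`, `naive_cached`) is hypothetical and only meant to show the failure mode: two concurrent callers both see a miss, so the expensive work runs twice.

```python
import asyncio

cache = {}      # naive two-state cache: key absent (miss) or present (hit)
work_done = []  # logs every time the expensive part actually runs

async def expensive(key):
    work_done.append(key)
    await asyncio.sleep(0.01)  # stand-in for slow async work
    return key.upper()

async def naive_cached(key):
    if key in cache:               # cache hit
        return cache[key]
    result = await expensive(key)  # cache miss: compute, then store
    cache[key] = result
    return result

async def main():
    # Two concurrent callers ask for the same key before either one finishes.
    return await asyncio.gather(naive_cached("a"), naive_cached("a"))

results = asyncio.run(main())
print(results)    # ['A', 'A']
print(work_done)  # ['a', 'a']: the same work ran twice
```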

Here’s the key idea:

We can sidestep the introspection issue by having the lru_cache wrap a dummy function that just returns a bit of persistent state.

We will initialize the state to a sentinel value.

Our outer async function checks for the sentinel value and replaces it with a Future representing an in-progress computation.

If the async function finds a Future, it registers a callback on that Future, politely waits for the callback to fire, then returns the result.

If the async function encounters anything other than the sentinel or a Future, it just returns whatever it finds: that's the cached result.

What follows is a lightly edited version of the class that first prompted this little diversion. It uses a pool of worker subprocesses to compute expensive Argon2 hashes and keeps a cache of recent inputs in order to avoid processing them more than once.

Note that this snippet does depend on the third-party PyNaCl library.

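from concurrent.futures import ProcessPoolExecutor
from functools import lru_cache
import asyncio

from nacl.pwhash import argon2id


def _relay(worker, waiter):
    # Copy the worker's outcome onto the waiter, including failures, so a
    # failed hash raises in every caller instead of leaving them hanging.
    if waiter.done():  # e.g. the waiting caller was cancelled
        return
    if worker.cancelled():
        waiter.cancel()
    elif worker.exception() is not None:
        waiter.set_exception(worker.exception())
    else:
        waiter.set_result(worker.result())


class Hasher:
    HASH_SIZE = 28  # 20 for addr + 8 for proof-of-work

    def __init__(self):
        self.pool = ProcessPoolExecutor()  # defaults to 1 worker per core

    @lru_cache(maxsize=128)  # one persistent state cell per (self, *args)
    def _dummy(self, *args):
        return [None]  # None means "computation has not started"

    async def do_hash(self, msg: bytes, salt: bytes):
        loop = asyncio.get_running_loop()
        args = (self.HASH_SIZE, msg, salt)

        cell = self._dummy(*args)

        if cell[0] is None:  # computation has not started
            cell[0] = loop.run_in_executor(self.pool, argon2id.kdf, *args)

        if asyncio.isfuture(cell[0]):  # computation is in progress
            new_future = loop.create_future()
            cell[0].add_done_callback(lambda worker: _relay(worker, new_future))
            cell[0] = await new_future  # replace worker with result

        return cell[0]  # computation is complete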

We use single-element lists as cheap, mutable bits of persistent state: lru_cache hands back the same list on every cache hit, so whatever we store in it sticks around. None serves as our sentinel value.

We can get away with using None and Futures directly here because we know argon2id.kdf only ever returns bytes, meaning there’s no chance of confusing None or Futures with actual return values. If we were dealing with a function that might return either of these values, we could move to storing 2-tuples where one entry is used as an explicit record of current state, like so: (INITIAL, None), (UNDERWAY, <Future>), (COMPLETE, <result>). We could also wrap this state in a custom class.

In light of the above discussion, here’s a generalized example. This one makes no assumptions about the async function it is caching. I bet you could adapt this into an async_lru_cache decorator pretty easily.

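from functools import lru_cache
from enum import Enum
import asyncio

async def slow_function(*args, **kwargs):
    ...  # do work, return result


class State(Enum):
    INITIAL = 1
    UNDERWAY = 2
    COMPLETE = 3

INITIAL, UNDERWAY, COMPLETE = State  # unpack the members for brevity

class Record:
    state = INITIAL
    value = None
    def update(self, state, value):
        self.state, self.value = state, value

def _relay(worker, waiter):
    # Copy the worker's outcome (result, exception or cancellation) onto waiter.
    if waiter.done():
        return
    if worker.cancelled():
        waiter.cancel()
    elif worker.exception() is not None:
        waiter.set_exception(worker.exception())
    else:
        waiter.set_result(worker.result())

@lru_cache  # default size: 128 (the bare decorator needs Python 3.8+)
def _dummy(*args, **kwargs):
    return Record()

async def cached_slow_function(*args, **kwargs):
    record = _dummy(*args, **kwargs)

    if record.state is INITIAL:
        # slow_function() returns a coroutine; wrapping it in a Task starts it
        # running and gives us something we can attach done-callbacks to
        record.update(UNDERWAY, asyncio.create_task(slow_function(*args, **kwargs)))

    if record.state is UNDERWAY:
        new_future = asyncio.get_running_loop().create_future()
        record.value.add_done_callback(lambda worker: _relay(worker, new_future))
        record.update(COMPLETE, await new_future)

    assert record.state is COMPLETE
    return record.value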
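As for turning the generalized example into a decorator, here's one way it might look. This is a minimal sketch, not a drop-in replacement for async_lru: `async_lru_cache` and `_Record` are hypothetical names, it assumes Python 3.7+ for `asyncio.create_task` and `asyncio.run`, and it waits on the shared task via asyncio.shield rather than a done-callback.

```python
import asyncio
import functools
from enum import Enum


class State(Enum):
    INITIAL = 1
    UNDERWAY = 2
    COMPLETE = 3


class _Record:
    """Mutable per-key state; one instance per cached argument tuple."""
    def __init__(self):
        self.state = State.INITIAL
        self.value = None


def async_lru_cache(maxsize=128):
    """Hypothetical decorator: lru_cache-style caching for coroutine functions."""
    def decorator(coro_fn):
        # lru_cache does the bookkeeping and eviction; it only ever stores Records.
        @functools.lru_cache(maxsize=maxsize)
        def _record_for(*args, **kwargs):
            return _Record()

        @functools.wraps(coro_fn)
        async def wrapper(*args, **kwargs):
            record = _record_for(*args, **kwargs)
            if record.state is State.INITIAL:
                record.state = State.UNDERWAY
                record.value = asyncio.create_task(coro_fn(*args, **kwargs))
            if record.state is State.UNDERWAY:
                task = record.value
                try:
                    # shield() keeps one cancelled caller from cancelling the shared task
                    result = await asyncio.shield(task)
                except BaseException:
                    # If the shared task itself failed or was cancelled, forget it so
                    # a later call retries instead of re-raising a cached failure.
                    failed = task.done() and (task.cancelled() or task.exception() is not None)
                    if failed and record.value is task:
                        record.state, record.value = State.INITIAL, None
                    raise
                record.state, record.value = State.COMPLETE, result
            return record.value

        wrapper.cache_info = _record_for.cache_info
        wrapper.cache_clear = _record_for.cache_clear
        return wrapper
    return decorator


calls = []

@async_lru_cache(maxsize=2)
async def double(x):
    calls.append(x)
    await asyncio.sleep(0.01)
    return x * 2

async def demo():
    together = await asyncio.gather(double(1), double(1), double(2))
    again = await double(1)
    return together, again

together, again = asyncio.run(demo())
print(together, again, calls)  # [2, 2, 4] 2 [1, 2]
```

One design choice worth calling out: a failed call is forgotten rather than cached, so the next caller retries, which matches lru_cache's own behavior of never caching exceptions. As with lru_cache, all arguments must be hashable.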

A quick aside for Python developers coming to asyncio from Twisted: one big difference between asyncio’s Futures and Twisted’s Deferreds is how callbacks receive values. A Deferred’s callbacks run as a chain, each one passed the previous callback’s return value. Every callback on a Future, by contrast, is passed a reference to the completed Future itself, and callback return values are simply discarded, so the value any given callback sees is totally independent of what the other callbacks do. This may or may not be what you were expecting.
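A quick way to see the difference: two callbacks on the same Future both receive the Future itself, and the first callback's return value goes nowhere. The callback names here are made up for the demo.

```python
import asyncio

seen = []  # what each callback received

def first_callback(fut):
    seen.append(("first", fut.result()))
    return "ignored"  # asyncio discards callback return values

def second_callback(fut):
    # Receives the same Future, not first_callback's return value.
    seen.append(("second", fut.result()))

async def main():
    fut = asyncio.get_running_loop().create_future()
    fut.add_done_callback(first_callback)
    fut.add_done_callback(second_callback)
    fut.set_result(42)
    await asyncio.sleep(0)  # callbacks are scheduled on the loop, so yield once

asyncio.run(main())
print(seen)  # [('first', 42), ('second', 42)]
```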

Another, more general note: we could just as easily have used 2-tuples here instead of class Record – and in fact, we’d probably get better performance that way – but I think the class makes the code easier to read. Similarly, the assert near the end is unnecessary, but it helps clarify the program logic. As a rule, I try to optimize for readability first, then profile the code and optimize for performance only in hotspots.

In this case, the overhead of instantiating Record objects and looking up their attributes is likely dwarfed by the cost of running slow_function itself, so performance tweaks here would yield marginal gains at best. Optimizing for readability – as Python is designed to encourage – is almost certainly the right choice.