Caching synchronous Python functions is easy – just use functools.lru_cache:

from functools import lru_cache

@lru_cache(maxsize=100)
def slow_function(*args):
    ...  # do something expensive and return result
The lru_cache decorator caches the results of slow_function so you can avoid unnecessary function calls. This is great for expensive or heavily recursive functions. When the cache reaches its maximum size, the Least Recently Used entry is discarded (this is what is meant by “LRU cache”). The size limit can be disabled by passing maxsize=None, at which point lru_cache becomes a plain memoizer.
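As a quick illustration, here’s the classic memoized Fibonacci – the cache_info method comes free with the decorator:

from functools import lru_cache

@lru_cache(maxsize=None)  # unbounded cache: lru_cache acts as a plain memoizer
def fib(n):
    return n if n < 2 else fib(n - 1) + fib(n - 2)

fib(30)
print(fib.cache_info())  # CacheInfo(hits=28, misses=31, maxsize=None, currsize=31)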
This is really useful. It’s a great example of how Python comes “batteries included”. So I was surprised when I recently found myself wanting to set up an LRU cache for an async function and realized there is no obvious way to do this.
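To see why, consider the naive approach of applying lru_cache directly to an async def. The decorator ends up caching the coroutine object itself, and a coroutine can only be awaited once – so the second lookup blows up (a minimal sketch, with a sleep standing in for real work):

from functools import lru_cache
import asyncio

@lru_cache(maxsize=100)
async def slow_function(x):
    await asyncio.sleep(1)  # stand-in for expensive work
    return x * 2

async def main():
    print(await slow_function(2))  # 4
    print(await slow_function(2))  # RuntimeError: cannot reuse already awaited coroutine

asyncio.run(main())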
The stdlib docs don’t mention this case. I didn’t see anything about it on Stack Overflow. I did find a third-party library: it’s called async_lru, and it offers a decorator for async functions that behaves similarly to lru_cache. That library’s author has basically just gone and reimplemented the stdlib decorator from scratch and thrown in a bunch of async defs. The implementation is a few hundred lines long and seems – I don’t know – fine, I guess, but I like to minimize third-party dependencies in my projects, especially for things which seem this simple. It seems like we should be able to do this without reimplementing logic that already exists in the stdlib.
And so, after a bit of hacking, I found a way to adapt the stdlib’s lru_cache to the async case. I’ve included both specific and general examples below.
The gist is this: the logic we want to offload involves maintaining a cache. When we query the cache, we have two cases: cache hit or cache miss. In the case of a cache hit, we can return the cached result; in the case of a cache miss, we have to do async work. We thus need our async function to be able to distinguish between hits and misses. Unfortunately, we have no way of introspecting into the decorator’s internal state.
In fact, things get even trickier, because introducing async workers adds a third case: not only could the result of our computation be absent or present (cache hit or cache miss), it could also be in progress. We want to be able to detect this case as well and avoid restarting any work we have already begun.
Here’s the key idea:

- We can sidestep the introspection issue by having lru_cache wrap a dummy function that just returns a bit of persistent state.
- We initialize the state to a sentinel value.
- Our outer async function checks for the sentinel value and replaces it with a Future representing an in-progress computation.
- If the async function finds a Future, it registers a callback on that Future, politely waits for the callback to fire, then returns the result.
- If the async function encounters anything other than the sentinel or a Future, it just returns whatever it finds.
What follows is a lightly edited version of the class that first prompted this little diversion. It uses a pool of worker subprocesses to compute expensive Argon2 hashes and keeps a cache of recent inputs in order to avoid processing them more than once. Note that this snippet does depend on the third-party PyNaCl library.
from concurrent.futures import ProcessPoolExecutor
from functools import lru_cache
import asyncio

from nacl.pwhash import argon2id

class Hasher:
    HASH_SIZE = 28  # 20 for addr + 8 for proof-of-work

    def __init__(self):
        self.pool = ProcessPoolExecutor()  # defaults to 1 worker per core

    @lru_cache(maxsize=500)
    def _dummy(self, *args):
        return [None]

    async def do_hash(self, msg: bytes, salt: bytes):
        loop = asyncio.get_event_loop()
        args = (self.HASH_SIZE, msg, salt)
        l = self._dummy(*args)
        if l[0] is None:  # computation has not started
            l[0] = loop.run_in_executor(self.pool, argon2id.kdf, *args)
        if asyncio.isfuture(l[0]):  # computation is in progress
            new_future = loop.create_future()
            l[0].add_done_callback(lambda worker: new_future.set_result(worker.result()))
            l[0] = await new_future  # replace worker with result
        return l[0]  # computation is complete
We use single-element lists as a cheap way of creating little bits of persistent state, and use None as our sentinel value.
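Usage looks something like this (a hypothetical sketch – the message and salt are made up; note that argon2id expects a salt of exactly argon2id.SALTBYTES, i.e. 16 bytes):

async def main():
    hasher = Hasher()
    salt = b"0123456789abcdef"  # 16 bytes, as argon2id requires
    digest = await hasher.do_hash(b"hello", salt)  # computed in a worker process
    again = await hasher.do_hash(b"hello", salt)   # served straight from the cache
    assert digest == again

asyncio.run(main())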
We can get away with using None and Futures directly here because we know argon2id.kdf only ever returns bytes, meaning there’s no chance of confusing None or Futures with actual return values. If we were dealing with a function that might return either of these values, we could move to storing 2-tuples where one entry is used as an explicit record of current state, like so: (INITIAL, None), (UNDERWAY, <Future>), (COMPLETE, <result>). We could also wrap this state in a custom class.
In light of the above discussion, here’s a generalized example. This one makes no assumptions about the async function it is caching. I bet you could adapt this into an async_lru_cache decorator pretty easily – there’s a sketch of one after the example.
from functools import lru_cache
from enum import Enum
import asyncio

async def slow_function(*args, **kwargs):
    ...  # do work, return result

INITIAL, UNDERWAY, COMPLETE = Enum('States', 'INITIAL UNDERWAY COMPLETE')

class Record:
    state = INITIAL
    value = None

    def update(self, state, value):
        self.state, self.value = state, value

@lru_cache  # default size: 128
def _dummy(*args, **kwargs):
    return Record()

async def cached_slow_function(*args, **kwargs):
    record = _dummy(*args, **kwargs)
    if record.state is INITIAL:
        # ensure_future wraps the coroutine in a Task, which starts it running
        # and (unlike a bare coroutine) supports add_done_callback
        record.update(UNDERWAY, asyncio.ensure_future(slow_function(*args, **kwargs)))
    if record.state is UNDERWAY:
        new_future = asyncio.get_event_loop().create_future()
        record.value.add_done_callback(lambda worker: new_future.set_result(worker.result()))
        record.update(COMPLETE, await new_future)
    assert record.state is COMPLETE
    return record.value
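Something like this might do it – a minimal sketch reusing the Record machinery above. The name async_lru_cache is just illustrative, and it awaits the Task directly rather than going through a callback, which is fine since asyncio Tasks tolerate multiple awaiters:

from functools import lru_cache, wraps
import asyncio

def async_lru_cache(maxsize=128):
    def decorator(coro_fn):
        @lru_cache(maxsize=maxsize)
        def _dummy(*args, **kwargs):
            return Record()  # one mutable Record per distinct set of arguments

        @wraps(coro_fn)
        async def wrapper(*args, **kwargs):
            record = _dummy(*args, **kwargs)
            if record.state is INITIAL:
                record.update(UNDERWAY, asyncio.ensure_future(coro_fn(*args, **kwargs)))
            if record.state is UNDERWAY:
                record.update(COMPLETE, await record.value)  # Tasks allow multiple awaiters
            return record.value
        return wrapper
    return decorator

@async_lru_cache(maxsize=500)
async def slow_function(*args, **kwargs):
    ...  # do work, return result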
A quick aside for Python devs coming to asyncio from a twisted background: one big difference between asyncio’s Futures and twisted’s Deferreds is that while Deferreds’ callbacks are invoked in a chain, with each callback being passed the previous one’s return value, each callback on a Future is passed a reference to the completed Future itself, i.e. the value passed to any given callback on a Future is totally independent of the behavior of any other callback. This may or may not be what you were expecting.
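To make that concrete, here’s a small demo. With a Deferred, if the first callback returned a new value, the second callback would receive that value; with a Future, both see the same completed Future and their return values are ignored:

import asyncio

async def main():
    fut = asyncio.get_event_loop().create_future()
    fut.add_done_callback(lambda f: print("first sees", f.result()))   # first sees 42
    fut.add_done_callback(lambda f: print("second sees", f.result()))  # second sees 42
    fut.set_result(42)
    await asyncio.sleep(0)  # yield to the event loop so the callbacks run

asyncio.run(main())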
Another, more general note: we could just as easily have used 2-tuples here instead of class Record – and in fact, we’d probably get better performance that way – but I think the code reads better this way. Similarly, the assert near the end is totally unnecessary, but it helps clarify the program logic. As a rule, I try to optimize for readability first, then profile the code and optimize for performance only in hotspots.
In this case, the overhead of instantiating and resolving references against Record objects is likely dwarfed by the overhead of running slow_function, so we would likely only see marginal benefit from performance optimizations, meaning that optimizing for readability – as Python is designed to encourage – is almost certainly the right choice.