class _CacheContext:
    """Holds cache information from the cached function higher in the calling order.

    Can be used to invalidate the higher level cache entry if something changes
    on a lower level.

    Obtain instances through `get_instance` rather than calling the constructor
    directly, so that at most one live instance exists per `(cache, cache_key)`
    pair.
    """

    # Registry of live instances, keyed by `(cache, cache_key)`. Values are
    # held weakly, so an entry is dropped once nothing else references the
    # context object. The key tuple is hashed, so both members must be
    # hashable.
    _cache_context_objects = WeakValueDictionary()

    def __init__(self, cache, cache_key):
        """
        Args:
            cache: the cache that holds the entry; must provide an
                `invalidate(key)` method.
            cache_key: the key of the entry within `cache`.
        """
        self._cache = cache
        self._cache_key = cache_key

    def invalidate(self):
        """Invalidates the cache entry referred to by the context."""
        self._cache.invalidate(self._cache_key)

    @classmethod
    def get_instance(cls, cache, cache_key):
        """Returns an instance constructed with the given arguments.

        A new instance is only created if none already exists.

        Args:
            cache: the cache that holds the entry.
            cache_key: the key of the entry within `cache`.

        Returns:
            The shared context object for `(cache, cache_key)`.
        """
        # We make sure there are no identical _CacheContext instances. This is
        # important in particular to dedupe when we add callbacks to lru cache
        # nodes, otherwise the number of callbacks would grow.
        registry = cls._cache_context_objects
        registry_key = (cache, cache_key)

        context = registry.get(registry_key)
        if context is None:
            # Only build a new object on a miss. Passing `cls(...)` straight to
            # `setdefault` would evaluate it (and allocate) on every call, even
            # when an instance is already registered. `setdefault` is still
            # used for the insert so an entry that appeared in the meantime
            # wins over the freshly built one.
            context = registry.setdefault(registry_key, cls(cache, cache_key))
        return context