forked from Mirror/pmbootstrap
meta: cache: fix caching and add tests (MR 2252)
Just use inspect... Fix some fairly big issues and add some tests Signed-off-by: Caleb Connolly <caleb@postmarketos.org>
This commit is contained in:
parent
2cf44da301
commit
29eb4e950e
2 changed files with 131 additions and 42 deletions
|
@ -4,6 +4,8 @@
|
|||
import copy
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import inspect
|
||||
|
||||
|
||||
class Wrapper:
|
||||
def __init__(self, cache: "Cache", func: Callable):
|
||||
|
@ -12,6 +14,8 @@ class Wrapper:
|
|||
self.disabled = False
|
||||
self.__module__ = func.__module__
|
||||
self.__name__ = func.__name__
|
||||
self.hits = 0
|
||||
self.misses = 0
|
||||
|
||||
|
||||
# When someone attempts to call a cached function, they'll
|
||||
|
@ -27,21 +31,23 @@ class Wrapper:
|
|||
key = self.cache.build_key(self.func, *args, **kwargs)
|
||||
# Don't cache
|
||||
if key is None:
|
||||
self.misses += 1
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
if key not in self.cache.cache:
|
||||
try:
|
||||
self.cache.cache[key] = self.func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
raise e
|
||||
elif self.cache.cache_deepcopy:
|
||||
self.cache.cache[key] = copy.deepcopy(self.cache.cache[key])
|
||||
self.misses += 1
|
||||
self.cache.cache[key] = self.func(*args, **kwargs)
|
||||
else:
|
||||
self.hits += 1
|
||||
if self.cache.cache_deepcopy:
|
||||
self.cache.cache[key] = copy.deepcopy(self.cache.cache[key])
|
||||
|
||||
#print(f"Cache: {func.__name__}({key})")
|
||||
return self.cache.cache[key]
|
||||
|
||||
def cache_clear(self):
|
||||
self.cache.clear()
|
||||
self.misses = 0
|
||||
self.hits = 0
|
||||
|
||||
def cache_disable(self):
|
||||
self.disabled = True
|
||||
|
@ -77,46 +83,39 @@ class Cache:
|
|||
if not self.params and not self.kwargs:
|
||||
return key
|
||||
|
||||
argnames = list(func.__code__.co_varnames)[:func.__code__.co_argcount]
|
||||
signature = inspect.signature(func)
|
||||
|
||||
# Build a dictionary of the arguments passed to the function and their values
|
||||
# including the default values
|
||||
# This is a silly mess because I wanted to avoid using inspect, but the reflection
|
||||
# stuff is not fun to work with...
|
||||
_kwargs = {}
|
||||
kwargs_start = len(argnames)-len(list(func.__defaults__ or [])) - 1
|
||||
for i in range(len(argnames)-1, 0, -1):
|
||||
arg = argnames[i]
|
||||
if arg not in self.kwargs:
|
||||
continue
|
||||
if arg in kwargs:
|
||||
_kwargs[argnames[i]] = kwargs[arg]
|
||||
elif i >= kwargs_start:
|
||||
#print(f"{func.__name__} -- {i}: {argnames[i]}")
|
||||
_kwargs[argnames[i]] = list(func.__defaults__ or [])[kwargs_start + i - 1]
|
||||
passed_args = dict(zip(argnames, args + tuple(_kwargs)))
|
||||
|
||||
#print(f"Cache.build_key({func}, {args}, {kwargs}) -- {passed_args}")
|
||||
if self.kwargs:
|
||||
for k, v in self.kwargs.items():
|
||||
if k not in argnames:
|
||||
raise ValueError(f"Cache key attribute {k} is not a valid parameter to {func.__name__}()")
|
||||
# Get the value passed into the function, or the default value
|
||||
# FIXME: could get a false hit if this is None
|
||||
passed_val = passed_args.get(k, _kwargs.get(k))
|
||||
if passed_val != v:
|
||||
return None
|
||||
passed_args: Dict[str, str] = {}
|
||||
for i, (k, val) in enumerate(signature.parameters.items()):
|
||||
if k in self.params or k in self.kwargs:
|
||||
if i < len(args):
|
||||
passed_args[k] = args[i]
|
||||
elif k in kwargs:
|
||||
passed_args[k] = kwargs[k]
|
||||
elif val.default != inspect.Parameter.empty:
|
||||
passed_args[k] = val.default
|
||||
else:
|
||||
key += f"{k}=({v})~"
|
||||
raise ValueError(f"Invalid cache key argument {k}"
|
||||
f" in function {func.__module__}.{func.__name__}")
|
||||
|
||||
for k, v in self.kwargs.items():
|
||||
if k not in signature.parameters.keys():
|
||||
raise ValueError(f"Cache key attribute {k} is not a valid parameter to {func.__name__}()")
|
||||
passed_val = passed_args[k]
|
||||
if passed_val != v:
|
||||
# Don't cache
|
||||
return None
|
||||
else:
|
||||
key += f"{k}=({v})~"
|
||||
|
||||
if self.params:
|
||||
for i, param in enumerate(args + tuple(kwargs.keys())):
|
||||
if argnames[i] in self.params[0]:
|
||||
if param.__str__ != object.__str__:
|
||||
key += f"{param}~"
|
||||
for k, v in passed_args.items():
|
||||
if k in self.params:
|
||||
if v.__str__ != object.__str__:
|
||||
key += f"{v}~"
|
||||
else:
|
||||
raise ValueError(f"Cache key argument {argnames[i]} to function"
|
||||
f" {func.__name__} must be a stringable type")
|
||||
raise ValueError(f"Cache key argument {k} to function"
|
||||
f" {func.__name__} must be a stringable type")
|
||||
|
||||
return key
|
||||
|
||||
|
|
90
pmb/meta/test_cache.py
Normal file
90
pmb/meta/test_cache.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
from typing import List
|
||||
import pytest
|
||||
|
||||
from . import Cache, Wrapper
|
||||
|
||||
|
||||
def test_cache_hits_basic():
|
||||
def multiply_2(x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
multiply_2_cached = Cache("x")(multiply_2)
|
||||
|
||||
assert isinstance(multiply_2_cached, Wrapper)
|
||||
|
||||
assert multiply_2(2) == 4
|
||||
|
||||
assert multiply_2_cached(2) == 4
|
||||
assert multiply_2_cached.misses == 1
|
||||
|
||||
assert multiply_2_cached(2) == 4
|
||||
assert multiply_2_cached.hits == 1
|
||||
|
||||
assert multiply_2_cached(3) == 6
|
||||
assert multiply_2_cached.misses == 2
|
||||
|
||||
assert multiply_2_cached(4) == 8
|
||||
assert multiply_2_cached.misses == 3
|
||||
|
||||
assert multiply_2_cached(3) == 6
|
||||
assert multiply_2_cached.hits == 2
|
||||
|
||||
def test_cache_hits_kwargs():
|
||||
def multiply_2(x: int, y: int = 2, z: List[int] = []) -> int:
|
||||
return x * y + sum(z)
|
||||
|
||||
multiply_2_cached = Cache("x", "y", "z")(multiply_2)
|
||||
|
||||
assert isinstance(multiply_2_cached, Wrapper)
|
||||
|
||||
assert multiply_2(2) == 4
|
||||
assert multiply_2_cached(2) == 4
|
||||
assert multiply_2_cached.misses == 1
|
||||
assert multiply_2(2, 3) == multiply_2_cached(2, 3)
|
||||
assert multiply_2_cached.misses == 2
|
||||
assert multiply_2(2, 3) == multiply_2_cached(2, 3)
|
||||
assert multiply_2_cached.hits == 1
|
||||
|
||||
assert multiply_2(3, 4, [1, 1]) == 14
|
||||
assert multiply_2_cached(3, 4, [1, 1]) == 14
|
||||
assert multiply_2_cached(3, 3, [1, 1]) == 11
|
||||
assert multiply_2_cached.misses == 4
|
||||
assert multiply_2_cached(3, 4, [1, 1]) == 14
|
||||
assert multiply_2_cached.hits == 2
|
||||
|
||||
# Should only cache when y=3
|
||||
multiply_2_cached_y3 = Cache("x", "z", y=3)(multiply_2)
|
||||
|
||||
assert multiply_2_cached_y3(1, 1, [1, 1]) == 3
|
||||
assert multiply_2_cached_y3.misses == 1
|
||||
|
||||
assert multiply_2_cached_y3(1, 1, [1, 1]) == 3
|
||||
assert multiply_2_cached_y3.misses == 2
|
||||
|
||||
assert multiply_2_cached_y3(1, 3, [4, 1]) == 8
|
||||
assert multiply_2_cached_y3.misses == 3
|
||||
assert multiply_2_cached_y3(1, 3, [4, 1]) == 8
|
||||
assert multiply_2_cached_y3.hits == 1
|
||||
|
||||
def test_build_key():
|
||||
def multiply_2(x: int, y: int = 2, z: List[int] = []) -> int:
|
||||
return x * y + sum(z)
|
||||
|
||||
multiply_2_cached = Cache("x", "y", "z")(multiply_2)
|
||||
|
||||
key = multiply_2_cached.cache.build_key(multiply_2, 1, 2, [3, 4])
|
||||
print(f"KEY: {key}")
|
||||
|
||||
assert key == "~1~2~[3, 4]~"
|
||||
|
||||
multiply_2_cached_y4 = Cache("x", "z", y=4)(multiply_2)
|
||||
|
||||
# Key should be None since y != 4
|
||||
key = multiply_2_cached_y4.cache.build_key(multiply_2, 1, 2, [3, 4])
|
||||
print(f"Expecting None KEY: {key}")
|
||||
assert key is None
|
||||
|
||||
# Now we expect a real key since y is 4
|
||||
key = multiply_2_cached_y4.cache.build_key(multiply_2, 1, 4, [3, 4])
|
||||
print(f"KEY: {key}")
|
||||
assert key == "~y=(4)~1~[3, 4]~"
|
Loading…
Add table
Add a link
Reference in a new issue