Cache hit/miss accounting and signature-based cache keys for pmb.meta.

Wrapper gains `hits` and `misses` counters: __call__ bumps them and
cache_clear() resets them. Cache.build_key() now resolves the key arguments
through inspect.signature() (positional value, then keyword value, then the
declared default) and raises ValueError when a key argument has none of them.
pmb/meta/test_cache.py is new and covers the counters and the generated keys.

Review amendments to the added lines of build_key() (the blob id after ".."
in the first index line predates them):
  * passed_args is annotated Dict[str, object]; it stores the raw argument
    values (the tests pass ints and lists), not strings.
  * the "no default" check compares against inspect.Parameter.empty with
    `is not`, so a default with a custom __eq__/__ne__ cannot skew it.

diff --git a/pmb/meta/__init__.py b/pmb/meta/__init__.py
index 88262b39..daa5dfb6 100644
--- a/pmb/meta/__init__.py
+++ b/pmb/meta/__init__.py
@@ -4,6 +4,8 @@
 import copy
 from typing import Callable, Dict, Optional
 
+import inspect
+
 
 class Wrapper:
     def __init__(self, cache: "Cache", func: Callable):
@@ -12,6 +14,8 @@ class Wrapper:
         self.disabled = False
         self.__module__ = func.__module__
         self.__name__ = func.__name__
+        self.hits = 0
+        self.misses = 0
 
 
     # When someone attempts to call a cached function, they'll
@@ -27,21 +31,23 @@ class Wrapper:
         key = self.cache.build_key(self.func, *args, **kwargs)
         # Don't cache
         if key is None:
+            self.misses += 1
             return self.func(*args, **kwargs)
 
         if key not in self.cache.cache:
-            try:
-                self.cache.cache[key] = self.func(*args, **kwargs)
-            except Exception as e:
-                raise e
-        elif self.cache.cache_deepcopy:
-            self.cache.cache[key] = copy.deepcopy(self.cache.cache[key])
+            self.misses += 1
+            self.cache.cache[key] = self.func(*args, **kwargs)
+        else:
+            self.hits += 1
+            if self.cache.cache_deepcopy:
+                self.cache.cache[key] = copy.deepcopy(self.cache.cache[key])
 
-        #print(f"Cache: {func.__name__}({key})")
         return self.cache.cache[key]
 
     def cache_clear(self):
         self.cache.clear()
+        self.misses = 0
+        self.hits = 0
 
     def cache_disable(self):
         self.disabled = True
@@ -77,46 +83,39 @@ class Cache:
         if not self.params and not self.kwargs:
             return key
 
-        argnames = list(func.__code__.co_varnames)[:func.__code__.co_argcount]
+        signature = inspect.signature(func)
 
-        # Build a dictionary of the arguments passed to the function and their values
-        # including the default values
-        # This is a silly mess because I wanted to avoid using inspect, but the reflection
-        # stuff is not fun to work with...
-        _kwargs = {}
-        kwargs_start = len(argnames)-len(list(func.__defaults__ or [])) - 1
-        for i in range(len(argnames)-1, 0, -1):
-            arg = argnames[i]
-            if arg not in self.kwargs:
-                continue
-            if arg in kwargs:
-                _kwargs[argnames[i]] = kwargs[arg]
-            elif i >= kwargs_start:
-                #print(f"{func.__name__} -- {i}: {argnames[i]}")
-                _kwargs[argnames[i]] = list(func.__defaults__ or [])[kwargs_start + i - 1]
-        passed_args = dict(zip(argnames, args + tuple(_kwargs)))
-
-        #print(f"Cache.build_key({func}, {args}, {kwargs}) -- {passed_args}")
-        if self.kwargs:
-            for k, v in self.kwargs.items():
-                if k not in argnames:
-                    raise ValueError(f"Cache key attribute {k} is not a valid parameter to {func.__name__}()")
-                # Get the value passed into the function, or the default value
-                # FIXME: could get a false hit if this is None
-                passed_val = passed_args.get(k, _kwargs.get(k))
-                if passed_val != v:
-                    return None
+        passed_args: Dict[str, object] = {}
+        for i, (k, val) in enumerate(signature.parameters.items()):
+            if k in self.params or k in self.kwargs:
+                if i < len(args):
+                    passed_args[k] = args[i]
+                elif k in kwargs:
+                    passed_args[k] = kwargs[k]
+                elif val.default is not inspect.Parameter.empty:
+                    passed_args[k] = val.default
                 else:
-                    key += f"{k}=({v})~"
+                    raise ValueError(f"Invalid cache key argument {k}"
+                                     f" in function {func.__module__}.{func.__name__}")
+
+        for k, v in self.kwargs.items():
+            if k not in signature.parameters.keys():
+                raise ValueError(f"Cache key attribute {k} is not a valid parameter to {func.__name__}()")
+            passed_val = passed_args[k]
+            if passed_val != v:
+                # Don't cache
+                return None
+            else:
+                key += f"{k}=({v})~"
 
         if self.params:
-            for i, param in enumerate(args + tuple(kwargs.keys())):
-                if argnames[i] in self.params[0]:
-                    if param.__str__ != object.__str__:
-                        key += f"{param}~"
+            for k, v in passed_args.items():
+                if k in self.params:
+                    if v.__str__ != object.__str__:
+                        key += f"{v}~"
                     else:
-                        raise ValueError(f"Cache key argument {argnames[i]} to function"
-                                         f" {func.__name__} must be a stringable type")
+                        raise ValueError(f"Cache key argument {k} to function"
+                                         f" {func.__name__} must be a stringable type")
 
         return key
 
diff --git a/pmb/meta/test_cache.py b/pmb/meta/test_cache.py
new file mode 100644
index 00000000..121d7f68
--- /dev/null
+++ b/pmb/meta/test_cache.py
@@ -0,0 +1,90 @@
+from typing import List
+import pytest
+
+from . import Cache, Wrapper
+
+
+def test_cache_hits_basic():
+    def multiply_2(x: int) -> int:
+        return x * 2
+
+    multiply_2_cached = Cache("x")(multiply_2)
+
+    assert isinstance(multiply_2_cached, Wrapper)
+
+    assert multiply_2(2) == 4
+
+    assert multiply_2_cached(2) == 4
+    assert multiply_2_cached.misses == 1
+
+    assert multiply_2_cached(2) == 4
+    assert multiply_2_cached.hits == 1
+
+    assert multiply_2_cached(3) == 6
+    assert multiply_2_cached.misses == 2
+
+    assert multiply_2_cached(4) == 8
+    assert multiply_2_cached.misses == 3
+
+    assert multiply_2_cached(3) == 6
+    assert multiply_2_cached.hits == 2
+
+def test_cache_hits_kwargs():
+    def multiply_2(x: int, y: int = 2, z: List[int] = []) -> int:
+        return x * y + sum(z)
+
+    multiply_2_cached = Cache("x", "y", "z")(multiply_2)
+
+    assert isinstance(multiply_2_cached, Wrapper)
+
+    assert multiply_2(2) == 4
+    assert multiply_2_cached(2) == 4
+    assert multiply_2_cached.misses == 1
+    assert multiply_2(2, 3) == multiply_2_cached(2, 3)
+    assert multiply_2_cached.misses == 2
+    assert multiply_2(2, 3) == multiply_2_cached(2, 3)
+    assert multiply_2_cached.hits == 1
+
+    assert multiply_2(3, 4, [1, 1]) == 14
+    assert multiply_2_cached(3, 4, [1, 1]) == 14
+    assert multiply_2_cached(3, 3, [1, 1]) == 11
+    assert multiply_2_cached.misses == 4
+    assert multiply_2_cached(3, 4, [1, 1]) == 14
+    assert multiply_2_cached.hits == 2
+
+    # Should only cache when y=3
+    multiply_2_cached_y3 = Cache("x", "z", y=3)(multiply_2)
+
+    assert multiply_2_cached_y3(1, 1, [1, 1]) == 3
+    assert multiply_2_cached_y3.misses == 1
+
+    assert multiply_2_cached_y3(1, 1, [1, 1]) == 3
+    assert multiply_2_cached_y3.misses == 2
+
+    assert multiply_2_cached_y3(1, 3, [4, 1]) == 8
+    assert multiply_2_cached_y3.misses == 3
+    assert multiply_2_cached_y3(1, 3, [4, 1]) == 8
+    assert multiply_2_cached_y3.hits == 1
+
+def test_build_key():
+    def multiply_2(x: int, y: int = 2, z: List[int] = []) -> int:
+        return x * y + sum(z)
+
+    multiply_2_cached = Cache("x", "y", "z")(multiply_2)
+
+    key = multiply_2_cached.cache.build_key(multiply_2, 1, 2, [3, 4])
+    print(f"KEY: {key}")
+
+    assert key == "~1~2~[3, 4]~"
+
+    multiply_2_cached_y4 = Cache("x", "z", y=4)(multiply_2)
+
+    # Key should be None since y != 4
+    key = multiply_2_cached_y4.cache.build_key(multiply_2, 1, 2, [3, 4])
+    print(f"Expecting None KEY: {key}")
+    assert key is None
+
+    # Now we expect a real key since y is 4
+    key = multiply_2_cached_y4.cache.build_key(multiply_2, 1, 4, [3, 4])
+    print(f"KEY: {key}")
+    assert key == "~y=(4)~1~[3, 4]~"