diff --git a/pmb/meta/__init__.py b/pmb/meta/__init__.py index 9c8b7792..a9f82638 100644 --- a/pmb/meta/__init__.py +++ b/pmb/meta/__init__.py @@ -2,13 +2,15 @@ # SPDX-License-Identifier: GPL-3.0-or-later import copy -from typing import Callable, Optional +from typing import Callable, Generic, Optional, TypeVar, overload import inspect +FuncArgs = TypeVar("FuncArgs") +FuncReturn = TypeVar("FuncReturn") -class Wrapper: - def __init__(self, cache: "Cache", func: Callable): +class Wrapper(Generic[FuncArgs, FuncReturn]): + def __init__(self, cache: "Cache", func: Callable[[FuncArgs], FuncReturn]): self.cache = cache self.func = func self.disabled = False @@ -21,7 +23,7 @@ class Wrapper: # actually end up here. We first check if we have a cached # result and if not then we do the actual function call and # cache it if applicable - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> FuncReturn: if self.disabled: return self.func(*args, **kwargs) @@ -123,7 +125,15 @@ class Cache: return key - def __call__(self, func: Callable): + @overload + def __call__(self, func: Callable[..., FuncReturn]) -> Wrapper[None, FuncReturn]: + ... + + @overload + def __call__(self, func: Callable[[FuncArgs], FuncReturn]) -> Wrapper[FuncArgs, FuncReturn]: + ... + + def __call__(self, func: Callable[[FuncArgs], FuncReturn]) -> Wrapper[FuncArgs, FuncReturn]: argnames = func.__code__.co_varnames for a in self.params: if a not in argnames: @@ -131,6 +141,7 @@ class Cache: f"Cache key attribute {a} is not a valid parameter to {func.__name__}()" ) + # FIXME: Once PEP-695 generics are in we shouldn't need this. return Wrapper(self, func) def clear(self):