"""\
Callback-based cache implementation (no storage, only triggers callbacks on set).
"""
__all__ = [
"CallbackCache",
]
from ..utils.basic.log_utils import get_logger
logger = get_logger(__name__)
from .base import BaseCache
from typing import Any, Generator, Optional, Iterable, Dict, Callable, Union
[docs]
class CallbackCache(BaseCache):
"""\
An implementation of BaseCache that does not cache any data, but calls callbacks on set, and feeds on get.
"""
[docs]
def __init__(
self,
callbacks: Optional[Iterable[Callable[[int, Dict[str, Any]], None]]] = None,
feeds: Optional[Iterable[Callable[[Callable, Any], None]]] = None,
exclude: Optional[Iterable[str]] = None,
*args,
**kwargs,
):
"""\
Initialization.
Args:
callbacks (Optional[Iterable[Callable[[int, Dict[str, Any]], None]]]): List of callback functions to call on set.
Each callback function must has API `callback(key: int, value: Dict[str, Any])`, which handles a cache set event.
feeds (Optional[Iterable[Callable[[Union[Callable, str], Any], None]]]): List of feed functions to call on get.
Each feed function must have API `feed(func: Union[Callable, str], **kwargs)`, which handles a cache get event.
The kwargs are the input to the function.
Notice that feeds must be ordered: the first feed function with a non-Ellipsis return value will be used.
exclude (Optional[Iterable[str]]): Keys to exclude from inputs when creating cache entries.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
"""
super().__init__(exclude=exclude, *args, **kwargs)
self.callbacks = callbacks or list()
self.feeds = feeds or list()
def _has(self, key: int) -> bool:
raise NotImplementedError("CallbackCache does not support `_has`.")
def _get(self, key: int, default: Any = ...) -> Dict[str, Any]:
raise NotImplementedError("CallbackCache does not support `_get`.")
def _set(self, key: int, value: Dict[str, Any]):
for cb in self.callbacks:
try:
cb(key, value)
except Exception as e:
logger.error(f"Error occurred in callback for key {key}: {e}. Skipped.")
pass # Ignore callback errors
def _remove(self, key: int):
raise NotImplementedError("CallbackCache does not support `_remove`.")
def __len__(self) -> int:
return 0
def _itervalues(self) -> Generator[Dict[str, Any], None, None]:
return iter([])
def _clear(self):
raise NotImplementedError("CallbackCache does not support `_clear`.")
[docs]
def get(self, func: Union[Callable, str], **kwargs) -> Any:
"""\
Retrieves a cached value for the given function and inputs.
Args:
func (Union[Callable, str]): The function or its name to retrieve the cached value for.
Notice that for `CallbackCache`, when all feed functions return ..., the function will be called:
# (deprecated) If the `func` is callable, it will be called with the provided keyword arguments.
# (deprecated) Otherwise, it will NOT be called.
For better stability, it is recommend to use a default feed function that can handle missing values.
**kwargs: Arbitrary keyword arguments representing the inputs to the function.
Returns:
Any: The cached output if found, otherwise `Ellipsis` (to avoid collisions with functions returning None).
"""
for fd in self.feeds:
try:
result = fd(func, **kwargs)
except Exception as e:
logger.error(f"Error occurred in feed for function {func}: {e}. Skipped.")
result = ...
if result is not ...:
return result
# if callable(func):
# result = func(**kwargs)
# return result
return ...