"""Source code for meshed.scrap.cached_dag."""

from typing import Mapping
from functools import partial
from collections import ChainMap
from meshed.itools import edge_reversed_graph, descendants
from i2 import Sig, Param, sort_params


class NotAllowed(Exception):
    """Base exception signaling that an operation is not permitted."""
class OverWritesNotAllowedError(NotAllowed):
    """Error to raise when writes to existing keys are not allowed."""
def get_first_item_and_assert_unicity(seq):
    """Return the only item of ``seq``, or ``None`` when ``seq`` is empty.

    An ``AssertionError`` is raised if ``seq`` holds more than one item.
    """
    if not seq:
        return None
    assert len(seq) == 1, (
        f'There should be one and one only item in the sequence: {seq}'
    )
    return seq[0]


def func_node_names_and_outs(dag):
    """Yield ``(name, out)`` pairs for every func node of ``dag``."""
    for node in dag.func_nodes:
        yield node.name, node.out
class NoOverwritesDict(dict):
    """
    A dict where you're not allowed to write to a key that already has a
    (different) value in it.

    >>> d = NoOverwritesDict(a=1, b=2)
    >>> d
    {'a': 1, 'b': 2}

    Writing is allowed, in new keys

    >>> d['c'] = 3
    >>> d
    {'a': 1, 'b': 2, 'c': 3}

    It's also okay to write into an existing key if the value it holds is
    identical. In fact, the write doesn't even happen.

    >>> d['b'] = 2

    But if we try to write a different value...

    >>> d['b'] = 22  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    cached_dag.OverWritesNotAllowedError: The b key already exists and you're not allowed to change its value
    """

    def __setitem__(self, key, value):
        # EAFP: fetch the current value; absence means the write is allowed.
        try:
            current = self[key]
        except KeyError:
            super().__setitem__(key, value)
            return
        if current != value:
            raise OverWritesNotAllowedError(
                f"The {key} key already exists and you're not allowed to change its "
                f'value'
            )
        # Same value already stored: skip the write entirely.
class NoSuchKey:
    """Sentinel used to distinguish 'key absent' from a stored ``None``."""


# TODO: Cache validation and invalidation
# TODO: Continue constructing upward towards lazyprop-using class (instances are
#  varnodes)
class CachedDag:
    """
    Wraps a DAG, using it to compute any of its var nodes from its dependents,
    with the capability of caching intermediate var nodes for later reuse.

    >>> def add(a, b=1):
    ...     return a + b
    >>> def mult(x, y=2):
    ...     return x * y
    >>> def subtract(a, b=4):
    ...     return a - b
    >>> from meshed import code_to_dag
    >>>
    >>> @code_to_dag(func_src=locals())
    ... def dag(w, ww, www):
    ...     x = mult(w, ww)
    ...     y = add(x, www)
    ...     z = subtract(x, y)
    >>> print(dag.dot_digraph_ascii())  # doctest: +SKIP

    (The printed diagram shows ``w, ww -> mult -> x``, ``x, www -> add -> y``
    and ``x, y -> subtract -> z``.)

    >>> from inspect import signature
    >>> g = CachedDag(dag)
    >>> signature(g)
    <Signature (k, /, **input_kwargs)>

    We can get ``ww`` because it has a default:

    (TODO: This (and further tests) stopped working since code_to_dag was
    enhanced with the ability to use the wrapped function's signature to
    determine the signature of the output dag. Need to fix this.)

    >>> g('ww')
    2

    But we can't get ``y`` because we don't have what it depends on:

    >>> g('y')
    Traceback (most recent call last):
      ...
    TypeError: The input_kwargs of a dag call is missing 1 required argument: 'w'

    It needs a ``w``? No, it needs an ``x``! But to get an ``x`` you need a
    ``w``, and...

    >>> g('x')
    Traceback (most recent call last):
      ...
    TypeError: The input_kwargs of a dag call is missing 1 required argument: 'w'

    So let's give it a w!

    >>> g('x', w=3)  # == 3 * 2 == 6
    6

    And now this works:

    >>> g('x')
    6

    because

    >>> g.cache
    {'x': 6}

    and this will work too:

    >>> g('y')
    7
    >>> g.cache
    {'x': 6, 'y': 7}

    But this is something we need to handle better!

    >>> g('x', w=10)
    6

    This is happening because there's already an x in the cache, and it takes
    precedence. This would be okay if we considered CachedDag as a low level
    object that is never actually used by a user. But we need to protect the
    user from such effects!

    First, we probably should cache inputs too. Then we can:

    - Make computation take precedence over cache, overwriting the existing
      cache with the new resulting values
    - Allow the user to declare the entire cache, or just some variables in it,
      as write-once, to avoid creating bugs with the above proposal.
    - Cache multiple paths (lru_cache style) for different input combinations
    """

    def __init__(self, dag, cache=True, name=None):
        """Wrap ``dag``; ``cache`` is ``True`` (fresh ``NoOverwritesDict``)
        or any ``Mapping`` to use as the cache; ``name`` overrides ``__name__``.
        """
        self.dag = dag
        self.reversed_graph = edge_reversed_graph(dag.graph_ids)
        self.roots = set(self.dag.roots)
        self.leafs = set(self.dag.leafs)
        self.var_nodes = set(self.dag.var_nodes)
        # Map a func node's output var id to the func node itself.
        self.func_node_of_id = {fn.out: fn for fn in self.dag.func_nodes}
        self.name = name
        # Map a func node's name to its output var id.
        self.out_of_func_node_name = dict(func_node_names_and_outs(self.dag))
        self._dag_sig = Sig(self.dag)
        self.defaults = self._dag_sig.defaults
        if cache is True:
            self.cache = NoOverwritesDict()
        elif isinstance(cache, Mapping):
            # Bug fix: a Mapping cache previously passed validation but was
            # never assigned, so the ChainMap below raised AttributeError.
            self.cache = cache
        else:
            raise NotImplementedError(
                'This type of cache is not implemented (must resolve to a Mapping): '
                f'{cache=}'
            )
        # Defaults sit under the cache so cached values shadow them.
        self._cache = ChainMap(self.defaults, self.cache)

    @property
    def __name__(self):
        return self.name or self.dag.__name__

    def __iter__(self):
        yield from self.reversed_graph

    def func_node_id(self, k):
        """Return the output var id of the (unique) func node producing ``k``,
        or ``None`` when ``k`` is a root node (nothing produces it)."""
        func_node_name = get_first_item_and_assert_unicity(self.reversed_graph[k])
        if func_node_name is not None:
            return self.out_of_func_node_name[func_node_name]

    # TODO: Consider having args and kwargs instead of just input_kwargs.
    #  or making it (k, /, *args, **kwargs)
    def __call__(self, k, /, **input_kwargs):
        """Compute var node ``k``, recursively resolving its dependencies from
        ``input_kwargs``, the cache, and dag defaults; caches intermediates.

        :raises ValueError: if ``input_kwargs`` shares keys with the cache
        :raises TypeError: if a required root value is missing
        """
        input_kwargs = dict(input_kwargs)
        if intersection := (input_kwargs.keys() & self.cache.keys()):
            # TODO: Can give the user a more informative/correct message, since
            #  the user has more options than just the root nodes: some
            #  combination of intermediates would also satisfy requirements.
            raise ValueError(
                "input_kwargs can't contain any keys that are already in cache! "
                f'These names were in both: {intersection}'
            )
        _cache = ChainMap(input_kwargs, self._cache)
        if k in _cache:
            return _cache[k]
        func_node_id = self.func_node_id(k)
        if func_node_id:
            if (output := self.cache.get(func_node_id)) is not None:
                return output
            else:
                func_node = self.func_node_of_id[func_node_id]
                # Recursively resolve every source var the func node binds to.
                input_sources = {
                    src: self(src, **input_kwargs) for src in func_node.bind.values()
                }
                # TODO: do we need to include **self.defaults in the middle?
                inputs = ChainMap(_cache, input_sources)
                output = func_node.call_on_scope(inputs, write_output_into_scope=False)
                self.cache[func_node_id] = output
                return output
        else:  # k is a root node
            assert k in self.roots, f'Was expecting this to be a root node: {k}'
            inputs = ChainMap(input_kwargs, self._cache)
            # NoSuchKey sentinel distinguishes "absent" from a stored None.
            if (output := inputs.get(k, NoSuchKey)) is not NoSuchKey:
                return output
            else:
                raise TypeError(
                    f'The input_kwargs of a {self.__name__} call is missing 1 required '
                    f"argument: '{k}'"
                )

    def _call(self, k, /, **kwargs):
        # Bindable alias of __call__, used by inject_methods via partial.
        return self(k, **kwargs)
[docs] def roots_for(self, node): """ The set of roots that lead to ``node``. >>> from meshed.makers import code_to_dag >>> @code_to_dag ... def dag(): ... x = mult(w, ww) ... y = add(x, www) ... z = subtract(x, y) >>> print(dag.synopsis_string()) w,ww -> mult -> x x,www -> add -> y x,y -> subtract -> z >>> g = CachedDag(dag) >>> sorted(g.roots_for('x')) ['w', 'ww'] >>> sorted(g.roots_for('y')) ['w', 'ww', 'www'] """ return set( filter(self.roots.__contains__, descendants(self.reversed_graph, node)) )
def _signature_for_node_method(self, node): def gen(): for name in filter(lambda x: x not in self.cache, self.roots_for(node)): yield Param( name=name, kind=Param.KEYWORD_ONLY, default=self.defaults.get(name, Param.empty), annotation=self._dag_sig.annotations.get(name, Param.empty), ) return Sig(sort_params(gen())) def inject_methods(self, obj=None): # TODO: Should be input_names of reversed_graph, but resulting "shadow" in # the root nodes, along with their defaults (filtered by cache) non_root_var_nodes = list(filter(lambda x: x not in self.roots, self.var_nodes)) if obj is None: from types import SimpleNamespace obj = SimpleNamespace(**{k: None for k in non_root_var_nodes}) for var_node in non_root_var_nodes: sig = self._signature_for_node_method(var_node) f = sig(partial(self._call, var_node)) setattr(obj, var_node, f) obj._cache = self.cache return obj
def cached_dag_test():
    """
    Covering issue https://github.com/i2mint/meshed/issues/34 about
    "CachedDag.cache should be populated with inputs that it was called on"
    """
    from meshed.dag import DAG

    def f(a, x=1):
        return a + x

    def g(a, y=2):
        return a * y

    dag = DAG([f, g])
    c = CachedDag(dag)
    c('g', a=1)
    assert c.cache == {'g': 2, 'a': 1}
    # Bug fix: was `assert c('f' == 2)`, i.e. asserted the truthiness of
    # c(False) instead of checking that c('f') equals 2.
    assert c('f') == 2
def add(a, b=1):
    """Return the sum ``a + b``."""
    return a + b


def mult(x, y=2):
    """Return the product ``x * y``."""
    return x * y


def exp(mult, n=3):
    """Return ``mult`` raised to the power ``n``."""
    return mult ** n


def subtract(a, b=4):
    """Return the difference ``a - b``."""
    return a - b


# from meshed import code_to_dag
#
#
# @code_to_dag(func_src=locals())
# def dag(w, ww, www):
#     x = mult(w, ww)
#     y = add(x, www)
#     z = subtract(x, y)
#
#
# g = CachedDag(dag)
#
# assert g('z', {'w': 2, 'ww': 3, 'www': 4}) == -4 == dag(2, 3, 4)