Source code for meshed.scrap.collapse_and_expand

"""Ideas on collapsing and expanding nodes
See "Collapse and expand nodes" discussion:
https://github.com/i2mint/meshed/discussions/54

"""

import re
from typing import Union, Iterable, Optional, Callable
from meshed.dag import DAG
from meshed.makers import code_to_dag


[docs] def remove_decorator_code( src: str, decorator_names: Optional[Union[str, Iterable[str]]] = None ) -> str: """ Remove the code corresponding to decorators from a source code string. If decorator_names is None, will remove all decorators. If decorator_names is an iterable of strings, will remove the decorators with those names. Examples: >>> src = ''' ... @decorator ... def func(): ... pass ... ''' >>> print(remove_decorator_code(src)) def func(): pass >>> src = ''' ... @decorator1 ... @decorator2 ... def func(): ... pass ... ''' >>> print(remove_decorator_code(src, "decorator1")) @decorator2 def func(): pass """ import ast if isinstance(decorator_names, str): decorator_names = [decorator_names] class DecoratorRemover(ast.NodeTransformer): def visit_FunctionDef(self, node): if node.decorator_list: if decorator_names is None: node.decorator_list = [] # Remove all decorators else: node.decorator_list = [ d for d in node.decorator_list if not (isinstance(d, ast.Name) and d.id in decorator_names) ] return node def visit_ClassDef(self, node): if node.decorator_list: if decorator_names is None: node.decorator_list = [] # Remove all decorators else: node.decorator_list = [ d for d in node.decorator_list if not (isinstance(d, ast.Name) and d.id in decorator_names) ] return node tree = ast.parse(src) new_tree = DecoratorRemover().visit(tree) ast.fix_missing_locations(new_tree) return ast.unparse(new_tree)
def get_src_string(src: Union[str, DAG]) -> str: if isinstance(src, str): return src elif hasattr(src, "_code_to_dag_src"): return get_src_string(src._code_to_dag_src) elif callable(src): import inspect return inspect.getsource(src) else: raise ValueError( f"src should be a string or have a _code_to_dag_src (meaning the src was " f"made with code_to_dag), not a {type(src)}" ) # TODO: Generalize to src that is any DAG
[docs] def collapse_function_calls( src: Union[str, DAG], call_func_name="call", *, rm_decorator="code_to_dag", include: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, ): """ Contract function calls in a source code string. That is, in source code, or a dag made from code_to_dag, replace calls of the form `call(func, arg)` with `func(arg)`. Note: Doesn't work with arbitrary DAG src, only those made from code_to_dag. """ src_string = get_src_string(src) def should_include(func_name): if include is None: return True if isinstance(include, Iterable): return func_name in include if callable(include): return include(func_name) return False pattern = call_func_name + r"\(([^,]+),\s*([^)]+)\)" def replace(match): func_name, args = match.groups() if should_include(func_name): return f"{func_name}({args})" return match.group(0) new_src = re.sub(pattern, replace, src_string) if rm_decorator: new_src = remove_decorator_code(new_src, decorator_names=rm_decorator) return new_src if isinstance(src, str) else code_to_dag(new_src)
[docs] def expand_function_calls( src: Union[str, "DAG"], call_func_name="call", *, include: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, ) -> str: """ Inverse of collapse_function_calls. It replaces calls of the form `func(arg)` with `call(func, arg)`, except when the function call is part of a function definition header. If include is None, it expands all function calls. If include is a list of function names, only those functions are expanded. If include is a callable, it's used as a filter function. """ src_string = get_src_string(src) def should_include(func_name): if include is None: return True if isinstance(include, Iterable): return func_name in include if callable(include): return include(func_name) return False pattern = r"(\b[a-zA-Z_]\w*)\(([^)]*)\)" def replace(match): # Get the start index of the match index = match.start() # Find the beginning of the current line line_start = src_string.rfind("\n", 0, index) + 1 # Extract the text from the start of the line up to the match current_line = src_string[line_start:index] # If the current line starts with a function definition, skip expansion. if re.match(r"^\s*def\s", current_line): return match.group(0) func_name, args = match.groups() if should_include(func_name): return f"{call_func_name}({func_name}, {args})" return match.group(0) new_src = re.sub(pattern, replace, src_string) return new_src if isinstance(src, str) else code_to_dag(new_src)
# ------------------------------------------------------------------------------ # Older code from dataclasses import dataclass from i2 import Sig from meshed.dag import DAG
[docs] @dataclass class CollapsedDAG: """To collapse a DAG into a single function This is useful for when you want to use a DAG as a function, but you don't want to see all the arguments. """ dag: DAG def __post_init__(self): Sig(self.dag)(self) # so that __call__ gets dag's signature self.__name__ = self.dag.name def __call__(self, *args, **kwargs): return self.dag(*args, **kwargs) def expand(self): return self.dag
# TODO: Finish this def expand_nodes( dag, nodes=None, *, is_node=lambda fnode, node: fnode.name == node or fnode.out == node, expansion_record_store=None, # TODO: Implement this to keep track of what was expanded ): if nodes is None: nodes = ... # find all func_nodes that have isinstance(fn.func, CollapsedDAG) def change_node_or_not(node): if is_node(node): return CollapsedDAG(node.func) else: return node return DAG(list(map(change_node_or_not, dag.func_nodes)))