"""Test dags"""
import pytest
from meshed.makers import code_to_dag
# Placeholder bindings for the names that appear inside the ``code_to_dag``
# specification functions below. They exist only so linters don't flag those
# names as undefined; every one of them is simply bound to None.
mult = add = subtract = None
w = ww = www = None
x = y = z = None
# DAG specification: ``code_to_dag`` turns the body below into a DAG whose
# inputs are ``w``, ``ww`` and ``www`` and whose steps are ``mult`` -> ``x``,
# ``add`` -> ``y`` and ``subtract`` -> ``z`` (see test_dag_operations).
# Built this way (no ``func_src``), calling the DAG returns a string describing
# the call tree rather than calling the ``None`` placeholders bound above.
# Comments are kept outside/ignored by the parser; the body itself is the spec,
# so do not restructure or rename anything inside it.
@code_to_dag()
def mult_add_subtract_dag():
    x = mult(w, ww)
    y = add(x, www)
    z = subtract(x, y)
def pass_on_tuple(a, b):
    """Return both arguments unchanged, packed as the tuple ``(a, b)``."""
    pair = (a, b)
    return pair
def add(x, y):
    """Return ``x + y``.

    Rebinds the module-level ``add`` placeholder, so that the DAG built from
    ``_expand_and_sum_dag`` (whose function names are resolved from this
    module's namespace) gets a real addition function.
    """
    total = x + y
    return total
# DAG specification for ``expand_and_sum_dag`` (built just below): never called
# directly. ``pass_on_tuple``'s output is unpacked into ``x`` and ``y``, which
# then feed ``add``; the DAG's inputs are ``w`` and ``ww``.
# NOTE(review): the body is parsed by ``code_to_dag``, so it is left without a
# docstring in case that would interfere with the conversion — confirm.
def _expand_and_sum_dag():
    x, y = pass_on_tuple(w, ww)
    result = add(x, y)
# Build the DAG from ``_expand_and_sum_dag``'s body, resolving the function
# names it mentions (``pass_on_tuple``, ``add``) from this module's namespace.
# At module scope ``globals()`` is the very mapping ``locals()`` returns; it is
# spelled explicitly here to make that intent obvious.
expand_and_sum_dag = code_to_dag(_expand_and_sum_dag, func_src=globals())
def test_code_to_dag_itemgetter():
    """The tuple returned by ``pass_on_tuple`` is unpacked into ``x`` and ``y``
    and then summed by ``add``, so inputs (2, 3) give 2 + 3."""
    summed = expand_and_sum_dag(2, 3)
    assert summed == 5
def test_dag_operations():
    """Check the signature and output of ``mult_add_subtract_dag`` before and
    after swapping its placeholder functions for real ones via ``ch_funcs``."""
    from i2 import Sig

    dag = mult_add_subtract_dag
    assert str(Sig(dag)) == '(w, ww, www)'
    # With the placeholder functions, calling the dag yields a string that
    # describes the call tree.
    assert (
        dag(1, 2, 3) == 'subtract(x=mult(w=1, ww=2), y=add(x=mult(w=1, ww=2), www=3))'
    )

    # Swap in concrete implementations; the signature must stay the same.
    dag = dag.ch_funcs(
        mult=lambda w, ww: w * ww,
        add=lambda x, www: x + www,
        subtract=lambda x, y: x - y,
    )
    assert str(Sig(dag)) == '(w, ww, www)'
    # mult(1, 2) = 2; add(2, 3) = 5; subtract(2, 5) = -3
    assert dag(1, 2, 3) == -3
def test_funcnode_bind():
    """
    Check how ``FuncNode``'s ``out`` and ``bind`` arguments rename a function's
    output and parameters, and how that changes the resulting ``DAG``.
    """
    from meshed import FuncNode
    from meshed.dag import DAG

    def f(a, b):
        return a + b

    def g(a_plus_b, d):
        return a_plus_b * d

    # Naming f's output 'a_plus_b' makes it flow into g's ``a_plus_b`` parameter.
    dag = DAG([FuncNode(func=f, out='a_plus_b'), FuncNode(func=g)])
    assert dag(a=1, b=2, d=3) == (1 + 2) * 3

    # Binding g's ``d`` to the variable ``b`` feeds the same value to both
    # functions, leaving a dag with only two inputs (a and b).
    dag = DAG(
        [
            FuncNode(func=f, out='a_plus_b'),
            FuncNode(func=g, bind={'d': 'b'}),
        ]
    )
    assert dag(a=1, b=2) == (1 + 2) * 2
def test_iterize_dag():
    """Show how to build an "iterized" version of a DAG: a callable with the
    same signature that maps the DAG over an iterable of first arguments."""
    from functools import partial
    from inspect import signature

    from i2 import Pipe
    from meshed import DAG

    def f(a, b=2):
        return a + b

    def g(f, c=3):
        return f * c

    d = DAG([f, g])
    expected = [9, 12, 15]  # (a + 2) * 3 for a in 1, 2, 3

    # The usual way to apply d to each element of an iterable:
    assert list(map(d, [1, 2, 3])) == expected

    def iterize(func):
        """Return ``partial(map, func)`` advertising ``func``'s signature."""
        mapped = partial(map, func)
        mapped.__signature__ = signature(func)
        return mapped

    di = iterize(d)
    # di reports the same signature as d...
    assert signature(di) == signature(d)
    # ...but takes an iterable of values for ``a`` and returns a lazy map
    # iterator, which must be consumed (here with list) whatever the input type.
    assert list(di([1, 2, 3])) == expected

    # To always get a concrete collection back (list, tuple, set,
    # numpy.array, ...), compose with a converter, e.g. with i2.Pipe:
    di_list = Pipe(di, list)
    assert di_list([1, 2, 3]) == expected
def test_binding_to_a_root_node():
    """
    Bind a parameter of a downstream node to a root (input) variable.

    Binding ``g``'s ``d`` to ``b`` merges two parameters into one DAG input.
    The merged parameters must agree on every property besides name (kind,
    default, annotation); otherwise ``DAG`` raises ``ValidationError``. This test
    shows the failure and several ways to resolve it.

    See: https://github.com/i2mint/meshed/issues/7
    """
    from meshed.dag import DAG
    from meshed.util import ValidationError
    from meshed import FuncNode

    def f(a, b):
        return a + b

    def g(a_plus_b, d):
        return a_plus_b * d

    # We bind d to b, and it works: (1 + 2) * 2 == 6
    f_node = FuncNode(func=f, out='a_plus_b')
    g_node = FuncNode(func=g, bind={'d': 'b'})
    dag = DAG((f_node, g_node))
    assert dag(a=1, b=2) == 6

    # But if b and d are not aligned on all parameter properties besides name
    # (kind, default, annotation), we get an error.
    # Here gg's d has a default (4) while f's b has none.
    def gg(a_plus_b, d=4):
        return a_plus_b * d

    gg_node = FuncNode(func=gg, bind={'d': 'b'})
    with pytest.raises(ValidationError) as e_info:
        _ = DAG((f_node, gg_node))
    assert "didn't have the same default" in e_info.value.args[0]

    # There are several solutions to this.
    # First, we can prepare the functions so that the defaults align.
    # The following shows two ways to do this.

    # 1: "Manually": wrap f in a function whose b has the same default as gg's d
    def ff(a, b=4):
        return f(a, b)

    ff_node = FuncNode(func=ff, out='a_plus_b')
    dag = DAG((ff_node, gg_node))
    assert dag(a=1, b=2) == 6

    # 2: With i2.Sig, giving f's b the default 4
    from i2 import Sig

    give_default_to_b = lambda func: Sig(func).ch_defaults(b=4)(func)
    ff_node = FuncNode(func=give_default_to_b(f), out='a_plus_b')
    dag = DAG((ff_node, gg_node))
    assert dag(a=1, b=2) == 6
    # And if you don't specify b, it takes the default you set: (1 + 4) * 4
    assert dag(a=1) == 20

    # Second, we could specify a different "merging policy": the function that
    # determines how to resolve several params with the same name (or binding)
    # that conflict on some property (kind, default and/or annotation).
    # Before we go there, note that the default is not the only possible
    # conflict: if the annotations or kinds differ, we run into the same problem
    # (with the same kinds of solutions).

    # NOTE(review): f is redefined here identically to the one above, presumably
    # to get a fresh function object in case ``Sig(func)...(func)`` above altered
    # ``f`` in place (e.g. gave ``b`` a default). Confirm before removing it.
    def f(a, b):
        return a + b

    def ggg(a_plus_b, d: int):  # note that d has no default, but an annotation
        return a_plus_b * d

    ggg_node = FuncNode(func=ggg, bind={'d': 'b'})
    # f_node is still the node built at the top of this test
    with pytest.raises(ValidationError) as e_info:
        _ = DAG((f_node, ggg_node))
    assert "didn't have the same annotation" in e_info.value.args[0]

    # Solution (with i2.Sig): annotate f's b with int, matching ggg's d
    give_annotation_to_b = lambda func: Sig(func).ch_annotations(b=int)(func)
    ff_node = FuncNode(func=give_annotation_to_b(f), out='a_plus_b')
    dag = DAG((ff_node, ggg_node))
    assert dag(a=1, b=2) == 6

    # The other solution to a parameter-property misalignment is to tell the DAG
    # constructor what to do with conflicts; for example, just ignore them.
    # (Not a good general policy though!)
    from meshed.dag import conservative_parameter_merge
    from functools import partial

    # conservative_parameter_merge with its kind/default/annotation agreement
    # checks switched off
    first_wins_all_merger = partial(
        conservative_parameter_merge,
        same_kind=False,
        same_default=False,
        same_annotation=False,
    )

    # Now f's b is positional-only and annotated int, while g's d is annotated
    # float with default 4: they disagree on kind, default and annotation.
    def f(a, b: int, /):
        return a + b

    def g(a_plus_b, d: float = 4):
        return a_plus_b * d

    lenient_dag_maker = partial(DAG, parameter_merge=first_wins_all_merger)
    f_node = FuncNode(func=f, out='a_plus_b')
    g_node = FuncNode(func=g, bind={'d': 'b'})
    dag = lenient_dag_maker([f_node, g_node])
    assert dag(1, 2) == 6
    # Note we can't call dag(a=1, b=2), since (like f's) its parameters are
    # positional-only: the dag inherits its parameters' properties from the
    # functions composing it, in this case f.
    # Resolving conflicts this way isn't the best general policy (that's why it's
    # not the default).
    # In production, it's advised to implement a more careful merging policy,
    # possibly specifying explicitly (in the `parameter_merge` callable itself)
    # what to do in every situation encountered.
def test_dag_partialize():
    """``DAG.partial`` fixes parameter defaults while keeping a valid signature:
    parameters that receive a default are moved after those without one."""
    from functools import partial
    from inspect import signature

    from i2 import Sig
    from meshed import DAG

    def foo(a, b):
        return a - b

    foo_dag = DAG([foo])
    assert str(Sig(foo_dag)) == '(a, b)'

    # Giving ``b`` a default keeps it positional-or-keyword...
    foo_dag_b9 = foo_dag.partial(b=9)
    assert str(Sig(foo_dag_b9)) == '(a, b=9)'
    # ...whereas functools.partial on foo makes it keyword-only.
    assert str(Sig(partial(foo, b=9))) == '(a, *, b=9)'
    assert foo_dag_b9(10) == foo_dag_b9(a=10) == 1

    # Fixing ``a`` (which comes before ``b``) reorders the parameters so the
    # signature stays valid.
    foo_dag_a4 = foo_dag.partial(a=4)
    assert str(Sig(foo_dag_a4)) == '(b, a=4)'
    # The underlying node uses the fixed a as well: 4 - 3 == 1
    first_node = foo_dag_a4.func_nodes[0]
    assert first_node.call_on_scope(dict(b=3)) == 1

    # The same on a larger dag computing h(f(a, b), g(c, d)) = g - f
    def f(a, b):
        return a + b

    def g(c, d=4):
        return c * d

    def h(f, g):
        return g - f

    larger_dag = DAG([f, g, h])
    partial_dag = larger_dag.partial(c=3, a=1)
    assert partial_dag(b=5, d=6) == 12  # 3 * 6 - (1 + 5)
    assert str(signature(partial_dag)) == '(b, a=1, c=3, d=4)'