"""Visualization utilities for the meshed package."""
from typing import Iterable, Any
from i2.signatures import Sig
[docs]
def dot_lines_of_objs(objs: Iterable, start_lines=(), end_lines=(), **kwargs):
    r"""
    Generate the dot lines of several objects, one object after the other.

    The lines produced are meant to be used as the body of a graphviz graph,
    as in ``graphviz.Digraph(body=list(...))``.

    :param objs: Iterable of objects that each have a ``dot_lines(**kwargs)``
        method returning an iterable of dot lines.
    :param start_lines: Lines to yield before the lines of any object.
    :param end_lines: Lines to yield after the lines of all objects.
    :param kwargs: Forwarded as-is to every ``obj.dot_lines`` call.
    :return: A generator of lines: ``start_lines``, then the lines of each
        object in iteration order, then ``end_lines``.

    >>> from meshed.base import FuncNode
    >>> def add(a, b=1):
    ...     return a + b
    >>> def mult(x, y=3):
    ...     return x * y
    >>> def exp(mult, a):
    ...     return mult ** a
    >>> func_nodes = [
    ...     FuncNode(add, out='x'),
    ...     FuncNode(mult, name='the_product'),
    ...     FuncNode(exp)
    ... ]
    >>> lines = list(dot_lines_of_objs(func_nodes))
    >>> assert lines == [
    ...     'x [label="x" shape="none"]',
    ...     '_add [label="_add" shape="box"]',
    ...     '_add -> x',
    ...     'a [label="a" shape="none"]',
    ...     'b [label="b=" shape="none"]',
    ...     'a -> _add',
    ...     'b -> _add',
    ...     'mult [label="mult" shape="none"]',
    ...     'the_product [label="the_product" shape="box"]',
    ...     'the_product -> mult',
    ...     'x [label="x" shape="none"]',
    ...     'y [label="y=" shape="none"]',
    ...     'x -> the_product',
    ...     'y -> the_product',
    ...     'exp [label="exp" shape="none"]',
    ...     '_exp [label="_exp" shape="box"]',
    ...     '_exp -> exp',
    ...     'mult [label="mult" shape="none"]',
    ...     'a [label="a" shape="none"]',
    ...     'mult -> _exp',
    ...     'a -> _exp'
    ... ]  # doctest: +SKIP

    >>> from meshed.util import dot_to_ascii
    >>>
    >>> print(dot_to_ascii('\n'.join(lines)))  # doctest: +SKIP
    <BLANKLINE>
                       a        ─┐
                                 │
              │              │
              │              │
              ▼              │
            ┌─────────────┐  │
    b=  ──▶ │    _add     │  │
            └─────────────┘  │
              │              │
              │              │
              ▼              │
                                 │
                       x         │
                                 │
              │              │
              │              │
              ▼              │
            ┌─────────────┐  │
    y=  ──▶ │ the_product │  │
            └─────────────┘  │
              │              │
              │              │
              ▼              │
                                 │
                     mult        │
                                 │
              │              │
              │              │
              ▼              │
            ┌─────────────┐  │
            │    _exp     │ ◀┘
            └─────────────┘
              │
              │
              ▼
    <BLANKLINE>
                      exp
    <BLANKLINE>
    """
    # Open question kept from the original author: should the objects be
    # validated here (e.g. with
    # meshed.base.validate_that_func_node_names_are_sane), or outside this
    # module?
    yield from start_lines
    # ``objs`` is only touched once the start lines are exhausted, and each
    # object's ``dot_lines`` is only called when its lines are reached.
    yield from (line for obj in objs for line in obj.dot_lines(**kwargs))
    yield from end_lines
dot_lines_of_func_nodes = dot_lines_of_objs  # backwards-compatibility alias
# TODO: Should we integrate this to dot_lines_of_func_parameters directly (decorator?)
[docs]
def add_new_line_if_none(s: str):
    """Return ``s`` terminated by a newline, appending one only if missing.

    Since graphviz 0.18, body lines need to end with a newline. This helper
    adds it to body lines that lack one.

    :param s: A body line.
    :return: ``s`` unchanged when it is empty or already ends with a newline,
        otherwise ``s`` with a newline appended.
    """
    # Empty (falsy) input is handed back untouched: no newline is added.
    if not s:
        return s
    if s.endswith('\n'):
        return s
    return s + '\n'
# ------------------------------------------------------------------------------
# Unused -- consider deleting
def _parameters_and_names_from_sig(
    sig: Sig, out=None, func_name=None,
):
    """Extract the parameters of ``sig`` along with output and function names.

    :param sig: Object exposing ``name`` and ``parameters`` attributes.
    :param out: Name of the output; falls back to ``sig.name`` when falsy.
    :param func_name: Name of the function; falls back to ``sig.name`` when
        falsy.
    :return: The triple ``(sig.parameters, out, func_name)``, where
        ``func_name`` is prefixed with an underscore if it would otherwise
        equal ``out``.
    """
    # ``sig.name`` is only looked up for the names that were not provided.
    if not func_name:
        func_name = sig.name
    if not out:
        out = sig.name

    # The function name must differ from the output name: prefix on clash.
    resolved_func_name = ('_' + func_name) if func_name == out else func_name

    assert isinstance(resolved_func_name, str) and isinstance(out, str)
    return sig.parameters, out, resolved_func_name
# ------------------------------------------------------------------------------
# Old stuff
def visualize_graph(graph):
    """Display a graph, given as an adjacency mapping, in a notebook.

    :param graph: Mapping whose keys are nodes and whose values are iterables
        of the neighbors each key points to. Keys and neighbors are passed
        directly to graphviz (presumably they need to be strings -- confirm
        against the graphviz version in use).
    :return: None. The graph is shown through ``IPython.display.display``.
    """
    import graphviz
    from IPython.display import display

    digraph = graphviz.Digraph()

    # Declare every key as a node first, so that keys without any neighbor
    # still show up.
    for vertex in graph:
        digraph.node(vertex)

    # Then draw one edge per (key, neighbor) pair.
    for tail, heads in graph.items():
        for head in heads:
            digraph.edge(tail, head)

    # Render and display the graph in the notebook
    display(digraph)
def visualize_graph_interactive(graph):
    """Display a directed graph in a notebook, with widgets to edit it.

    The graph is held in a ``networkx.DiGraph`` built from ``graph``, which is
    the single source of truth. After every edit, a fresh
    ``graphviz.Digraph`` is generated from it and rendered as SVG into an
    HTML widget. Regenerating (instead of patching the lines of an existing
    ``Digraph.body``) avoids two problems: deleting a node by substring match
    also removed lines of other nodes whose names contain that substring, and
    deleting an edge by exact body line raised ``ValueError`` whenever
    graphviz formatted or quoted that line differently.

    Four rows of controls are displayed under the graph: add edge, add node,
    delete edge and delete node. Each control clears its text fields once
    clicked; blank inputs are ignored.

    NOTE(review): names typed in the text fields are used as node keys as-is
    (presumably always strings), so nodes of the initial ``graph`` that are
    not strings cannot be targeted by the delete controls -- confirm whether
    that matters for callers.

    :param graph: Initial graph data, passed to ``networkx.DiGraph``.
    :return: None. The widgets are shown through ``IPython.display.display``.
    """
    import graphviz
    import networkx as nx
    import ipywidgets as widgets
    from IPython.display import display

    g = nx.DiGraph(graph)

    def render_svg():
        """Build a Graphviz digraph from the current state of ``g`` as SVG."""
        dot = graphviz.Digraph()
        for node in g.nodes:
            dot.node(str(node))
        for source, target in g.edges:
            dot.edge(str(source), str(target))
        return dot.pipe(format='svg').decode('utf-8')

    # Render the initial graph visualization
    graph_widget = widgets.HTML(value=render_svg())
    display(graph_widget)

    def refresh():
        """Redraw the displayed graph from the current state of ``g``."""
        graph_widget.value = render_svg()

    def add_edge(sender):
        source = source_node.value
        target = target_node.value
        # Both ends must be named; an already existing edge is left alone.
        if source and target and (source, target) not in g.edges:
            g.add_edge(source, target)
            refresh()
        source_node.value = ''
        target_node.value = ''

    def add_node(sender):
        node = new_node.value
        if node and node not in g.nodes:
            g.add_node(node)
            refresh()
        new_node.value = ''

    def delete_edge(sender):
        source = str(delete_source.value)
        target = str(delete_target.value)
        if (source, target) in g.edges:
            g.remove_edge(source, target)
            refresh()
        delete_source.value = ''
        delete_target.value = ''

    def delete_node(sender):
        node = delete_node_value.value
        if node in g.nodes:
            # Removing the node from ``g`` also removes its incident edges,
            # so the regenerated drawing drops exactly this node's lines.
            g.remove_node(node)
            refresh()
        delete_node_value.value = ''

    source_node = widgets.Text(placeholder='Source Node')
    target_node = widgets.Text(placeholder='Target Node')
    add_edge_button = widgets.Button(description='Add Edge')
    add_edge_button.on_click(add_edge)

    new_node = widgets.Text(placeholder='New Node')
    add_node_button = widgets.Button(description='Add Node')
    add_node_button.on_click(add_node)

    delete_source = widgets.Text(placeholder='Source Node')
    delete_target = widgets.Text(placeholder='Target Node')
    delete_edge_button = widgets.Button(description='Delete Edge')
    delete_edge_button.on_click(delete_edge)

    delete_node_value = widgets.Text(placeholder='Node')
    delete_node_button = widgets.Button(description='Delete Node')
    delete_node_button.on_click(delete_node)

    controls = widgets.HBox([source_node, target_node, add_edge_button])
    controls2 = widgets.HBox([new_node, add_node_button])
    controls3 = widgets.HBox([delete_source, delete_target, delete_edge_button])
    controls4 = widgets.HBox([delete_node_value, delete_node_button])

    display(controls)
    display(controls2)
    display(controls3)
    display(controls4)