"""
seriously modified version of yahoo/graphkit
"""
# ---------- base --------------------------------------------------------------
class Data:
    """
    Wrapper around any data that is consumed or produced by an Operation.

    The wrapped data should also know how to serialize itself appropriately.
    This is an "abstract" class: it is meant to be extended by any class
    working with data in the HiC framework, and both accessors raise until a
    subclass overrides them.
    """

    def __init__(self, **kwargs):
        """Accept and ignore arbitrary keyword arguments."""

    def get_data(self):
        """Return the wrapped data; must be provided by subclasses."""
        raise NotImplementedError()

    def set_data(self, data):
        """Store ``data`` in the wrapper; must be provided by subclasses."""
        raise NotImplementedError()
from dataclasses import dataclass, field
@dataclass
class Operation:
    """
    This is an abstract class representing a data transformation. To use this,
    please inherit from this class and customize the ``.compute`` method to your
    specific application.

    Names may be given to this layer and its inputs and outputs. This is
    important when connecting layers and data in a Network object, as the
    names are used to construct the graph.

    :param str name: The name of the operation (e.g. conv1, conv2, etc..)
    :param list needs: Names of input data objects this layer requires.
    :param list provides: Names of output data objects this provides.
    :param dict params: A dict of key/value pairs representing parameters
                        associated with your operation. These values will be
                        accessible using the ``.params`` attribute of your object.
                        NOTE: It's important that any values stored in this
                        argument must be picklable.
    """

    name: str = field(default='None')
    needs: list = field(default=None)
    provides: list = field(default=None)
    params: dict = field(default_factory=dict)

    def __post_init__(self):
        """
        This method is a hook for you to override. It gets called after this
        object has been initialized with its ``needs``, ``provides``, ``name``,
        and ``params`` attributes (including after it is restored by
        ``__setstate__``). People often override this method to implement
        custom loading logic required for objects that do not pickle easily,
        and for initialization of c++ dependencies.
        """
        pass

    def __eq__(self, other):
        """
        Operation equality is based on name of layer.
        (__eq__ and __hash__ must be overridden together)
        """
        return bool(self.name is not None and self.name == getattr(other, 'name', None))

    def __hash__(self):
        """
        Operation hashing is based on name of layer.
        (__eq__ and __hash__ must be overridden together)
        """
        return hash(self.name)

    def compute(self, inputs):
        """
        This method must be implemented to perform this layer's feed-forward
        computation on a given set of inputs.

        :param list inputs:
            A list of :class:`Data` objects on which to run the layer's
            feed-forward computation.
        :returns list:
            Should return a list of :class:`Data` objects representing
            the results of running the feed-forward computation on
            ``inputs``.
        """
        raise NotImplementedError

    def _compute(self, named_inputs, outputs=None):
        """
        Run ``compute`` on inputs looked up by name and label the results.

        :param dict named_inputs:
            Mapping from data name to value; every name in ``needs`` is
            looked up in it (a missing name raises ``KeyError``).
        :param outputs:
            Optional iterable of names; when truthy, only results whose name
            is in it are returned.
        :returns dict:
            The results of ``compute`` keyed by the names in ``provides``.
        """
        inputs = [named_inputs[d] for d in self.needs]
        results = zip(self.provides, self.compute(inputs))
        if outputs:
            wanted = set(outputs)
            results = (pair for pair in results if pair[0] in wanted)
        return dict(results)

    def __getstate__(self):
        """
        This allows your operation to be pickled.
        Everything needed to instantiate your operation should be defined by the
        following attributes: params, needs, provides, and name
        No other piece of state should leak outside of these 4 variables
        """
        result = {}
        # this check should get deprecated soon. its for downward compatibility
        # with earlier pickled operation objects
        if hasattr(self, 'params'):
            result['params'] = self.__dict__['params']
        result['needs'] = self.__dict__['needs']
        result['provides'] = self.__dict__['provides']
        result['name'] = self.__dict__['name']
        return result

    def __setstate__(self, state):
        """
        Load from pickle: restore every saved attribute, then re-run the
        ``__post_init__`` hook so subclasses can rebuild unpicklable state.
        """
        for key, value in state.items():
            setattr(self, key, value)
        # The dataclass hook is spelled ``__post_init__``; the previous
        # ``__postinit__`` spelling did not exist and broke unpickling.
        self.__post_init__()

    def __repr__(self):
        """
        Display more informative names for the Operation class
        """
        return "%s(name='%s', needs=%s, provides=%s)" % (
            self.__class__.__name__,
            self.name,
            self.needs,
            self.provides,
        )
class NetworkOperation(Operation):
    """
    An operation that wraps a whole compiled network (``net``) so the network
    can be run, plotted and pickled like a single operation.

    :param net: The network object to delegate to. It is passed as the
                required ``net`` keyword; remaining keywords go to
                :class:`Operation`.
    """

    # Class-level fallback: instances restored without ``__init__`` (for
    # example by unpickling state that predates this attribute) still run.
    _execution_method = 'sequential'

    def __init__(self, **kwargs):
        self.net = kwargs.pop('net')
        Operation.__init__(self, **kwargs)
        # set execution mode to single-threaded sequential by default
        self._execution_method = 'sequential'

    def _compute(self, named_inputs, outputs=None):
        """Run the wrapped network with the configured execution method."""
        return self.net.compute(outputs, named_inputs, method=self._execution_method)

    def __call__(self, *args, **kwargs):
        """Calling the operation is the same as calling ``_compute``."""
        return self._compute(*args, **kwargs)

    def set_execution_method(self, method):
        """
        Determine how the network will be executed.

        Args:
            method: str
                If "parallel", execute graph operations concurrently
                using a threadpool. If "sequential", run them one at a time.
        """
        options = ['parallel', 'sequential']
        assert method in options
        self._execution_method = method

    def plot(self, filename=None, show=False):
        """Forward to the wrapped network's ``plot`` with the same arguments."""
        self.net.plot(filename=filename, show=show)

    def __getstate__(self):
        """
        Extend the base pickled state with the wrapped network and the
        selected execution method, so neither is lost on a round trip.
        """
        state = Operation.__getstate__(self)
        state['net'] = self.__dict__['net']
        state['_execution_method'] = self._execution_method
        return state
# ------------ modifiers -------------------------------------------------------
"""
This sub-module contains input/output modifiers that can be applied to
arguments to ``needs`` and ``provides`` to let GraphKit know it should treat
them differently.
Copyright 2016, Yahoo Inc.
Licensed under the terms of the Apache License, Version 2.0. See the LICENSE
file associated with the project for terms.
"""
class optional(str):
    """
    Marker type for names in ``needs`` that are optional inputs.

    When a name in ``needs`` is wrapped in this modifier, the value is handed
    to the ``operation`` only if it is available. The function underlying the
    ``operation`` should have a parameter with the same name as the input
    value in ``needs``; the value is passed as a keyword argument when it is
    available.

    Here is an example of an operation that uses an optional argument::

        from graphkit import operation, compose
        from graphkit.modifiers import optional

        # Function that adds either two or three numbers.
        def myadd(a, b, c=0):
            return a + b + c

        # Designate c as an optional argument.
        graph = compose('mygraph')(
            operation(name='myadd', needs=['a', 'b', optional('c')], provides='sum')(myadd)
        )

        # The graph works with and without 'c' provided as input.
        assert graph({'a': 5, 'b': 2, 'c': 4})['sum'] == 11
        assert graph({'a': 5, 'b': 2})['sum'] == 7
    """
# ------------ network ------------------------------------------------------
# Copyright 2016, Yahoo Inc.
# Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms.
from contextlib import suppress
with suppress(ModuleNotFoundError, ImportError):
import time
import os
import networkx as nx
from io import StringIO
# uses base.Operation
class DataPlaceholderNode(str):
    """
    Graph node holding the name of a Data instance that some layer of the
    Network either produces or requires.
    """

    def __repr__(self):
        """Show the node type together with the wrapped name."""
        return f'DataPlaceholderNode("{self}")'
class DeleteInstruction(str):
    """
    Step in the compiled evaluation list telling the Network to free (delete)
    the named Data instance from its cache once it is no longer needed.
    """

    def __repr__(self):
        """Show the instruction type together with the data name."""
        return f'DeleteInstruction("{self}")'
class Network(object):
    """
    This is the main network implementation. The class contains all of the
    code necessary to weave together operations into a directed-acyclic-graph (DAG)
    and pass data through.
    """

    def __init__(self, **kwargs):
        """
        :param bool debug:
            Optional keyword. When true, progress messages are printed while
            compiling and computing. Defaults to ``False``.
        """
        # directed graph of layer instances and data-names defining the net.
        self.graph = nx.DiGraph()
        self._debug = kwargs.get('debug', False)

        # this holds the timing information for each layer
        self.times = {}

        # a compiled list of steps to evaluate layers *in order* and free mem.
        self.steps = []

        # This holds a cache of results for the _find_necessary_steps
        # function, this helps speed up the compute call as well avoid
        # a multithreading issue that is occurring when accessing the
        # graph in networkx
        self._necessary_steps_cache = {}

    def add_op(self, operation):
        """
        Adds the given operation and its data requirements to the network graph
        based on the name of the operation, the names of the operation's needs, and
        the names of the data it provides.

        :param Operation operation: Operation object to add.
        """
        # assert layer and its data requirements are named.
        assert operation.name, 'Operation must be named'
        assert operation.needs is not None, "Operation's 'needs' must be named"
        assert (
            operation.provides is not None
        ), "Operation's 'provides' must be named"

        # assert layer is only added once to graph
        assert (
            operation not in self.graph.nodes()
        ), 'Operation may only be added once'

        # add nodes and edges to graph describing the data needs for this layer
        for n in operation.needs:
            self.graph.add_edge(DataPlaceholderNode(n), operation)

        # add nodes and edges to graph describing what this layer provides
        for p in operation.provides:
            self.graph.add_edge(operation, DataPlaceholderNode(p))

        # clear compiled steps (must recompile after adding new layers); the
        # cached step selections describe the old graph, so drop them too.
        self.steps = []
        self._necessary_steps_cache = {}

    def list_layers(self):
        """Return ``(name, operation)`` pairs for every compiled operation step."""
        assert self.steps, 'network must be compiled before listing layers.'
        return [(s.name, s) for s in self.steps if isinstance(s, Operation)]

    def show_layers(self):
        """Shows info (name, needs, and provides) about all layers in this network."""
        for name, step in self.list_layers():
            print('layer_name: ', name)
            print('\t', 'needs: ', step.needs)
            print('\t', 'provides: ', step.provides)
            print('')

    def compile(self):
        """Create a set of steps for evaluating layers
        and freeing memory as necessary"""
        # clear compiled steps and any step selections cached for the
        # previous compilation
        self.steps = []
        self._necessary_steps_cache = {}

        # create an execution order such that each layer's needs are provided.
        ordered_nodes = list(nx.dag.topological_sort(self.graph))

        # add Operations evaluation steps, and instructions to free data.
        for i, node in enumerate(ordered_nodes):
            if isinstance(node, DataPlaceholderNode):
                continue
            if not isinstance(node, Operation):
                raise TypeError('Unrecognized network graph node')

            # add layer to list of steps
            self.steps.append(node)

            # Operations that come later in the execution order.
            future_ops = [
                n for n in ordered_nodes[i + 1:] if isinstance(n, Operation)
            ]

            # Add instructions to delete predecessors as possible. A
            # predecessor may be deleted if it is a data placeholder that
            # is no longer needed by future Operations.
            for predecessor in self.graph.predecessors(node):
                if self._debug:
                    print('checking if node %s can be deleted' % predecessor)
                if any(predecessor in op.needs for op in future_ops):
                    continue
                if self._debug:
                    print(' adding delete instruction for %s' % predecessor)
                self.steps.append(DeleteInstruction(predecessor))

    def _find_necessary_steps(self, outputs, inputs):
        """
        Determines what graph steps need to be run to get to the requested
        outputs from the provided inputs. Eliminates steps that come before
        (in topological order) any inputs that have been provided. Also
        eliminates steps that are not on a path from the provided inputs to
        the requested outputs.

        :param list outputs:
            A list of desired output names. This can also be ``None``, in which
            case the necessary steps are all graph nodes that are reachable
            from one of the provided inputs.
        :param dict inputs:
            A dictionary mapping names to values for all provided inputs.
        :returns:
            Returns a list of all the steps that need to be run for the
            provided inputs and requested outputs.
        :raises ValueError:
            If a requested output name is not a node of the graph.
        """
        # Normalise to a hashable, order-independent key so equivalent
        # requests (list, set or tuple, in any order) share one cache entry.
        if isinstance(outputs, (list, set, tuple)):
            outputs = tuple(sorted(outputs))
        cache_key = (tuple(sorted(inputs.keys())), outputs)

        # return steps if it has already been computed before for this set
        # of inputs and outputs
        if cache_key in self._necessary_steps_cache:
            return self._necessary_steps_cache[cache_key]

        graph = self.graph
        if not outputs:
            # If caller requested all outputs, the necessary nodes are all
            # nodes that are reachable from one of the inputs. Ignore input
            # names that aren't in the graph.
            necessary_nodes = set()
            for input_name in inputs:
                if graph.has_node(input_name):
                    necessary_nodes |= nx.descendants(graph, input_name)
        else:
            # If the caller requested a subset of outputs, find any nodes that
            # are made unnecessary because we were provided with an input
            # that's deeper into the network graph. Ignore input names that
            # aren't in the graph.
            unnecessary_nodes = set()
            for input_name in inputs:
                if graph.has_node(input_name):
                    unnecessary_nodes |= nx.ancestors(graph, input_name)

            # Find the nodes we need to be able to compute the requested
            # outputs. Raise an exception if a requested output doesn't
            # exist in the graph.
            necessary_nodes = set()
            for output_name in outputs:
                if not graph.has_node(output_name):
                    raise ValueError(
                        'graphkit graph does not have an output '
                        'node named %s' % output_name
                    )
                necessary_nodes |= nx.ancestors(graph, output_name)

            # Get rid of the unnecessary nodes from the set of necessary ones.
            necessary_nodes -= unnecessary_nodes

        # Keep the compiled (topological) order of the needed steps.
        necessary_steps = [step for step in self.steps if step in necessary_nodes]

        # save this result in a precomputed cache for future lookup
        self._necessary_steps_cache[cache_key] = necessary_steps
        return necessary_steps

    def compute(self, outputs, named_inputs, method=None):
        """
        Run the graph. Any inputs to the network must be passed in by name.

        :param list outputs: The names of the data node you'd like to have returned
                             once all necessary computations are complete.
                             If you set this variable to ``None``, all
                             data nodes will be kept and returned at runtime.
        :param dict named_inputs: A dict of key/value pairs where the keys
                                  represent the data nodes you want to populate,
                                  and the values are the concrete values you
                                  want to set for the data node.
        :param str method: ``'parallel'`` runs operations on a thread pool;
                           any other value runs them sequentially.
        :returns: a dictionary of output data objects, keyed by name.
        """
        # assert that network has been compiled
        assert self.steps, 'network must be compiled before calling compute.'
        assert (
            isinstance(outputs, (list, tuple)) or outputs is None
        ), 'The outputs argument must be a list'

        # choose a method of execution
        if method == 'parallel':
            return self._compute_thread_pool_barrier_method(named_inputs, outputs)
        return self._compute_sequential_method(named_inputs, outputs)

    def _compute_thread_pool_barrier_method(
        self, named_inputs, outputs, thread_pool_size=10
    ):
        """
        This method runs the graph using a parallel pool of thread executors.
        You may achieve lower total latency if your graph is sufficiently
        sub divided into operations using this method.

        :param dict named_inputs: Input values keyed by data name.
        :param outputs: Names to return, or a falsy value to return everything.
        :param int thread_pool_size:
            Size of the thread pool; only used the first time a pool is
            created for this network, after which the pool is reused.
        :returns: a dictionary of output data objects, keyed by name.
        """
        from multiprocessing.dummy import Pool

        # if we have not already created a thread_pool, create one
        if not hasattr(self, '_thread_pool'):
            self._thread_pool = Pool(thread_pool_size)
        pool = self._thread_pool

        cache = {}
        cache.update(named_inputs)
        necessary_nodes = self._find_necessary_steps(outputs, named_inputs)

        # this keeps track of all nodes that have already executed
        has_executed = set()

        # NOTE(review): readiness is judged by ready_to_schedule_operation
        # against every ancestor operation in the graph, including ones that
        # are not among the necessary steps -- confirm that is intended when
        # intermediate data is supplied as input.

        # with each loop iteration, we determine a set of operations that can be
        # scheduled, then schedule them onto a thread pool, then collect their
        # results onto a memory cache for use upon the next iteration.
        while True:
            # the upnext list contains a list of operations for scheduling
            # in the current round of scheduling
            upnext = []
            for node in necessary_nodes:
                if isinstance(node, DeleteInstruction):
                    # Same rule as the sequential method: never drop data
                    # when everything was requested, or when the data is one
                    # of the requested outputs. Beyond that, only delete once
                    # all successors of the data node have been executed.
                    if (
                        outputs
                        and node not in outputs
                        and ready_to_delete_data_node(node, has_executed, self.graph)
                    ):
                        cache.pop(node, None)
                    continue

                # continue if this node is anything but an operation node
                if not isinstance(node, Operation):
                    continue

                if node not in has_executed and ready_to_schedule_operation(
                    node, has_executed, self.graph
                ):
                    upnext.append(node)

            # stop if no nodes left to schedule, exit out of the loop
            if not upnext:
                break

            done_iterator = pool.imap_unordered(
                lambda op: (op, op._compute(cache)), upnext
            )
            for op, result in done_iterator:
                cache.update(result)
                has_executed.add(op)

        if not outputs:
            return cache
        return {k: v for k, v in cache.items() if k in outputs}

    def _compute_sequential_method(self, named_inputs, outputs):
        """
        This method runs the graph one operation at a time in a single thread

        :param dict named_inputs: Input values keyed by data name.
        :param outputs: Names to return, or a falsy value to return everything.
        :returns: a dictionary of output data objects, keyed by name.
        """
        # start with fresh data cache and add inputs to it
        cache = {}
        cache.update(named_inputs)

        # Find the subset of steps we need to run to get to the requested
        # outputs from the provided inputs.
        all_steps = self._find_necessary_steps(outputs, named_inputs)

        # timings are reset on every run
        self.times = {}
        for step in all_steps:
            if isinstance(step, Operation):
                if self._debug:
                    print('-' * 32)
                    print('executing step: %s' % step.name)

                # time execution...
                t0 = time.time()

                # compute layer outputs and add them to the cache
                layer_outputs = step._compute(cache)
                cache.update(layer_outputs)

                # record execution time
                t_complete = round(time.time() - t0, 5)
                self.times[step.name] = t_complete
                if self._debug:
                    print('step completion time: %s' % t_complete)

            # Process DeleteInstructions by deleting the corresponding data
            # if possible.
            elif isinstance(step, DeleteInstruction):
                if outputs and step not in outputs:
                    # Some DeleteInstruction steps may not exist in the cache
                    # if they come from optional() needs that are not provided
                    # as inputs. Make sure the step exists before deleting.
                    if step in cache:
                        if self._debug:
                            print("removing data '%s' from cache." % step)
                        cache.pop(step)
            else:
                raise TypeError('Unrecognized instruction.')

        if not outputs:
            # Return the whole cache as output, including input and
            # intermediate data nodes.
            return cache

        # Filter outputs to just return what's needed.
        return {k: v for k, v in cache.items() if k in outputs}

    def plot(self, filename=None, show=False):
        """
        Plot the graph.

        :param str filename:
            Write the output to a png, jpg/jpeg, pdf, svg, or graphviz dot
            file. The extension controls the output format.
        :param boolean show:
            If this is set to True, use matplotlib to show the graph diagram
            (Default: False)
        :returns:
            An instance of the pydot graph
        :raises ValueError:
            If ``filename`` has an extension that is not supported.
        """
        # pydot is required for every code path below, so let a missing
        # install surface as an ImportError instead of a later NameError.
        import pydot

        assert self.graph is not None

        def get_node_name(a):
            if isinstance(a, DataPlaceholderNode):
                return a
            return a.name

        g = pydot.Dot(graph_type='digraph')

        # draw nodes
        for nx_node in self.graph.nodes():
            if isinstance(nx_node, DataPlaceholderNode):
                node = pydot.Node(name=nx_node, shape='rect')
            else:
                node = pydot.Node(name=nx_node.name, shape='circle')
            g.add_node(node)

        # draw edges
        for src, dst in self.graph.edges():
            edge = pydot.Edge(src=get_node_name(src), dst=get_node_name(dst))
            g.add_edge(edge)

        # save plot
        if filename:
            ext = os.path.splitext(filename)[1].lower()
            # extension -> name of the pydot ``create_<format>`` renderer
            rendered_formats = {
                '.png': 'png',
                '.jpg': 'jpeg',
                '.jpeg': 'jpeg',
                '.pdf': 'pdf',
                '.svg': 'svg',
            }
            if ext == '.dot':
                # the dot source is text
                with open(filename, 'w') as fh:
                    fh.write(g.to_string())
            elif ext in rendered_formats:
                # rendered output is bytes, so the file must be binary
                rendered = getattr(g, 'create_' + rendered_formats[ext])()
                with open(filename, 'wb') as fh:
                    fh.write(rendered)
            else:
                # checked before opening so no empty file is left behind
                raise ValueError(
                    'Unknown file format for saving graph: %s' % ext
                )

        # display graph via matplotlib
        if show:
            from io import BytesIO

            import matplotlib.image as mpimg
            import matplotlib.pyplot as plt

            # the rendered PNG is bytes, hence BytesIO rather than StringIO
            img = mpimg.imread(BytesIO(g.create_png()))
            plt.imshow(img, aspect='equal')
            plt.show()

        return g
def ready_to_schedule_operation(op, has_executed, graph):
    """
    Tell whether an Operation may be scheduled, given what has already run.

    An operation is ready once every Operation among its ancestors in the
    graph is present in ``has_executed``.

    Args:
        op:
            The Operation object to check
        has_executed: set
            A set containing all operations that have been executed so far
        graph:
            The networkx graph containing the operations and data nodes
    Returns:
        A boolean indicating whether the operation may be scheduled for
        execution based on what has already been executed.
    """
    upstream_ops = {
        ancestor
        for ancestor in nx.ancestors(graph, op)
        if isinstance(ancestor, Operation)
    }
    return upstream_ops <= set(has_executed)
def ready_to_delete_data_node(name, has_executed, graph):
    """
    Tell whether a DataPlaceholderNode may be removed from the cache.

    The data node is looked up by name; it is deletable once every one of
    its successors in the graph is present in ``has_executed``.

    Args:
        name:
            The name of the data node to check
        has_executed: set
            A set containing all operations that have been executed so far
        graph:
            The networkx graph containing the operations and data nodes
    Returns:
        A boolean indicating whether the data node can be deleted or not.
    """
    consumers = set(graph.successors(get_data_node(name, graph)))
    return consumers <= set(has_executed)
def get_data_node(name, graph):
    """
    Look up a data node in ``graph`` by its name.

    Returns the first graph node that equals ``name`` and is a
    DataPlaceholderNode, or ``None`` when there is no such node.
    """
    matches = (
        candidate
        for candidate in graph.nodes()
        if candidate == name and isinstance(candidate, DataPlaceholderNode)
    )
    return next(matches, None)
# ------------ functional ------------------------------------------------------
# Copyright 2016, Yahoo Inc.
# Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms.
from itertools import chain
# uses Operation, NetworkOperation from base
# uses Network from network
class FunctionalOperation(Operation):
    """
    An Operation whose computation is a plain callable.

    The callable is passed as the required ``fn`` keyword; every other
    keyword is forwarded to :class:`Operation`.
    """

    def __init__(self, **kwargs):
        self.fn = kwargs.pop('fn')
        Operation.__init__(self, **kwargs)

    def _compute(self, named_inputs, outputs=None):
        """
        Call ``fn`` with inputs looked up by name and label its result.

        Required needs are passed positionally in ``needs`` order. Needs
        wrapped in ``optional`` are passed by keyword, and only when they are
        present in ``named_inputs``. ``params`` are passed by keyword too; an
        optional input with the same key takes precedence.

        :param dict named_inputs: Mapping from data name to value.
        :param outputs:
            Optional iterable of names; when truthy, only results whose name
            is in it are returned.
        :returns dict: The result(s) of ``fn`` keyed by ``provides``.
        """
        # Required inputs first, in declaration order.
        positional = []
        for need in self.needs:
            if not isinstance(need, optional):
                positional.append(named_inputs[need])

        # Keyword arguments: params, then whichever optionals were supplied
        # (no extra ``None`` values for the ones that are absent).
        keyword = dict(self.params)
        for need in self.needs:
            if isinstance(need, optional) and need in named_inputs:
                keyword[need] = named_inputs[need]

        returned = self.fn(*positional, **keyword)

        # A single provided name means ``fn`` returns the bare value.
        values = [returned] if len(self.provides) == 1 else returned
        labelled = zip(self.provides, values)
        if not outputs:
            return dict(labelled)
        wanted = set(outputs)
        return {key: value for key, value in labelled if key in wanted}

    def __call__(self, *args, **kwargs):
        """Call the underlying function directly with the given arguments."""
        return self.fn(*args, **kwargs)

    def __getstate__(self):
        """Extend the base pickled state with the wrapped function."""
        pickled = Operation.__getstate__(self)
        pickled['fn'] = self.__dict__['fn']
        return pickled
class operation(Operation):
    """
    This object represents an operation in a computation graph.  Its
    relationship to other operations in the graph is specified via its
    ``needs`` and ``provides`` arguments.

    :param function fn:
        The function used by this operation.  This does not need to be
        specified when the operation object is instantiated and can instead
        be set via ``__call__`` later.
    :param str name:
        The name of the operation in the computation graph.
    :param list needs:
        Names of input data objects this operation requires.  These should
        correspond to the ``args`` of ``fn``.
    :param list provides:
        Names of output data objects this operation provides.
    :param dict params:
        A dict of key/value pairs representing constant parameters
        associated with your operation.  These can correspond to either
        ``args`` or ``kwargs`` of ``fn``.
    """

    def __init__(self, fn=None, **kwargs):
        self.fn = fn
        Operation.__init__(self, **kwargs)

    def _normalize_kwargs(self, kwargs):
        """
        Validate and normalise the keyword arguments used to build the
        functional operation.

        A single string given for ``needs`` or ``provides`` is wrapped in a
        list, and a ``params`` value that is not a dict is replaced by an
        empty dict. The mapping is modified in place and returned.

        :param dict kwargs:
            Must contain ``name``, ``needs``, ``provides``, ``fn`` and
            ``params``.
        :raises AssertionError: If any of the values fails validation.
        """
        # Allow single value for needs parameter
        if isinstance(kwargs.get('needs'), str):
            assert kwargs['needs'], 'empty string provided for `needs` parameters'
            kwargs['needs'] = [kwargs['needs']]

        # Allow single value for provides parameter
        if isinstance(kwargs.get('provides'), str):
            assert kwargs[
                'provides'
            ], 'empty string provided for `provides` parameters'
            kwargs['provides'] = [kwargs['provides']]

        assert kwargs['name'], 'operation needs a name'
        assert isinstance(kwargs['needs'], list), 'no `needs` parameter provided'
        assert isinstance(
            kwargs['provides'], list
        ), 'no `provides` parameter provided'
        assert callable(kwargs['fn']), 'operation was not provided with a callable'

        # isinstance keeps dict subclasses instead of silently dropping them
        if not isinstance(kwargs.get('params'), dict):
            kwargs['params'] = {}
        return kwargs

    def __call__(self, fn=None, **kwargs):
        """
        This enables ``operation`` to act as a decorator or as a functional
        operation, for example::

            @operation(name='myadd1', needs=['a', 'b'], provides=['c'])
            def myadd(a, b):
                return a + b

        or::

            def myadd(a, b):
                return a + b
            operation(name='myadd1', needs=['a', 'b'], provides=['c'])(myadd)

        :param function fn:
            The function to be used by this ``operation``.
        :return:
            Returns an operation class that can be called as a function or
            composed into a computation graph.
        """
        if fn is not None:
            self.fn = fn

        # Attributes set at construction time, overridden by call-time kwargs.
        total_kwargs = {}
        total_kwargs.update(vars(self))
        total_kwargs.update(kwargs)
        total_kwargs = self._normalize_kwargs(total_kwargs)
        return FunctionalOperation(**total_kwargs)

    def __repr__(self):
        """
        Display more informative names for the Operation class
        """
        # ``fn`` defaults to None and some callables have no ``__name__``;
        # fall back to the object itself rather than raising.
        return "%s(name='%s', needs=%s, provides=%s, fn=%s)" % (
            self.__class__.__name__,
            self.name,
            self.needs,
            self.provides,
            getattr(self.fn, '__name__', self.fn),
        )
class compose(object):
    """
    This is a simple class that's used to compose ``operation`` instances into
    a computation graph.

    :param str name:
        A name for the graph being composed by this object.
    :param bool merge:
        If ``True``, this compose object will attempt to merge together
        ``operation`` instances that represent entire computation graphs.
        Specifically, if one of the ``operation`` instances passed to this
        ``compose`` object is itself a graph operation created by an
        earlier use of ``compose`` the sub-operations in that graph are
        compared against other operations passed to this ``compose``
        instance (as well as the sub-operations of other graphs passed to
        this ``compose`` instance).  If any two operations are the same
        (based on name), then that operation is computed only once, instead
        of multiple times (one for each time the operation appears).
    """

    def __init__(self, name=None, merge=False):
        assert name, 'compose needs a name'
        self.name = name
        self.merge = merge

    def __call__(self, *operations):
        """
        Composes a collection of operations into a single computation graph,
        obeying the ``merge`` property, if set in the constructor.

        :param operations:
            Each argument should be an operation instance created using
            ``operation``.
        :return:
            Returns a special type of operation class, which represents an
            entire computation graph as a single operation.
        """
        assert len(operations), 'no operations provided to compose'

        # If merge is desired, deduplicate operations before building network.
        # A dict keeps first-seen order, so the composed ``needs`` and
        # ``provides`` come out in the same order on every run.
        if self.merge:
            merged = {}
            for op in operations:
                if isinstance(op, NetworkOperation):
                    for step in op.net.steps:
                        if isinstance(step, Operation):
                            merged.setdefault(step, None)
                else:
                    merged.setdefault(op, None)
            operations = list(merged)

        def order_preserving_uniquifier(seq, seen=None):
            # Keep the first occurrence of each item, skipping anything
            # already in ``seen``.
            seen = seen if seen else set()
            unique = []
            for item in seq:
                if item not in seen:
                    seen.add(item)
                    unique.append(item)
            return unique

        provides = order_preserving_uniquifier(
            chain(*[op.provides for op in operations])
        )
        # Anything provided inside the graph is not an external need.
        needs = order_preserving_uniquifier(
            chain(*[op.needs for op in operations]), set(provides)
        )

        # compile network
        net = Network()
        for op in operations:
            net.add_op(op)
        net.compile()

        return NetworkOperation(
            name=self.name, needs=needs, provides=provides, params={}, net=net
        )