Created December 7, 2023 03:48
-
-
Save omc8db/2f17e7e96b779fae1c645598f165b6c1 to your computer and use it in GitHub Desktop.
pyStateGraph
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect
from collections import defaultdict
from collections.abc import Callable
class StateGraph:
    """Dictionary that lets you add keys calculated from other keys.

    When an entry is updated, every calculated entry that depends on it
    (directly or transitively) is recomputed. Each affected calculation runs
    exactly once per update, after all of its own inputs have been refreshed.

    Values may be any Python object. ``None`` is reserved to mean "undefined":
    a calculation with any ``None`` input is not called and its own value
    becomes ``None``.

    A calculation's inputs are named by its function's positional parameter
    names, so keys used as inputs must be valid names for function arguments.

    add_calculation allows you to add a new entry that is recalculated on the
    fly as its dependencies update. Dependency cycles are rejected.

    Example:
        s = StateGraph()
        s["foo"] = 1
        # Create an entry "baz" that is automatically calculated from foo and bar
        s.add_calculation("baz", lambda foo, bar: 3 * foo + 2 * bar)
        print(s["baz"])  # Prints "None" because not all inputs for baz are defined
        # Calculated entries can depend on other calculated entries
        s.add_calculation("qux", lambda baz: baz * 2)
        s["bar"] = 2
        # These values were recalculated when their dependency bar changed
        print(s["baz"])  # Prints 7
        print(s["qux"])  # Prints 14
    """

    def __init__(self):
        # Current value of every known key; None means "undefined".
        self._measurements = {}
        # input key -> calculated keys that read it (edges of the dependency graph)
        self._deps = defaultdict(list)
        # calculated key -> (evaluator, ordered input names)
        self._functions: dict[str, tuple[Callable, list[str]]] = {}

    def __getitem__(self, key):
        """Return the current value of *key*; raise KeyError if it is unknown."""
        return self._measurements[key]

    def __setitem__(self, key, value):
        """Set *key* to *value* and recompute every calculation downstream of it."""
        self._measurements[key] = value
        # Topological order: each dependent is evaluated once, with all of its
        # inputs already up to date, so no transient stale ("glitch") values.
        for dependent in self._downstream(key):
            self._measurements[dependent] = self._evaluate(dependent)

    def __str__(self):
        return str(self._measurements)

    def add_calculation(self, key: str, function: Callable):
        """Register *key* as ``function(*inputs)`` and compute it immediately.

        The positional parameter names of *function* name its input keys;
        inputs that do not exist yet are created with value ``None``.
        Registering a key that already has a calculation replaces it.

        Raises:
            ValueError: if the calculation would make *key* depend on itself,
                directly or through other calculations. The graph is left
                unchanged in that case.
        """
        deps = self._input_names(function)
        downstream = set(self._downstream(key))
        cyclic = [dep for dep in deps if dep == key or dep in downstream]
        if cyclic:
            raise ValueError(
                f"calculation {key!r} would create a dependency cycle via {cyclic}"
            )
        # Drop edges from a previous calculation for this key so it is not
        # triggered (and evaluated) more than once per update.
        if key in self._functions:
            for old_dep in self._functions[key][1]:
                self._deps[old_dep].remove(key)
        for dep in deps:
            self._deps[dep].append(key)
            # Referenced inputs that don't exist yet start undefined
            self._measurements.setdefault(dep, None)
        self._functions[key] = function, deps
        self._calculate(key)

    @staticmethod
    def _input_names(function):
        """Return the names of *function*'s positional parameters, in order.

        ``inspect.signature`` is used instead of ``getfullargspec`` because the
        latter reports the already-bound ``self`` of a bound method, which would
        be mistaken for an (always undefined) input.
        """
        positional = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        return [
            name
            for name, param in inspect.signature(function).parameters.items()
            if param.kind in positional
        ]

    def _downstream(self, key):
        """Return every calculated key depending on *key*, in evaluation order.

        Each key appears once, after all of its inputs that are also downstream
        of *key* (reverse post-order of a depth-first search). The search is
        iterative so long dependency chains cannot hit the recursion limit.
        """
        postorder = []
        seen = {key}
        stack = [(key, iter(self._deps.get(key, ())))]
        while stack:
            node, children = stack[-1]
            for child in children:
                if child not in seen:
                    seen.add(child)
                    stack.append((child, iter(self._deps.get(child, ()))))
                    break
            else:
                stack.pop()
                postorder.append(node)
        postorder.pop()  # *key* itself finishes last; it is not its own dependent
        postorder.reverse()
        return postorder

    def _evaluate(self, key):
        """Compute calculated *key* from current values; None if any input is None."""
        function, input_names = self._functions[key]
        inputs = [self._measurements[name] for name in input_names]
        if any(value is None for value in inputs):
            return None
        return function(*inputs)

    def _calculate(self, key: str):
        """Recompute calculated *key* and propagate the result downstream."""
        self[key] = self._evaluate(key)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
License is public domain, copy-paste wherever you need it.