Last active
November 19, 2019 14:23
-
-
Save klieret/0f3b1b8a0b9b143702ae1aa662264029 to your computer and use it in GitHub Desktop.
Oftentimes, we call functions over a wide range of parameter values, e.g. to generate plots for different scenarios. With several parameters, this leads to deeply nested loops and cumbersome boilerplate; with this snippet, it becomes a single call.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools
from typing import Any, Callable, Dict, List, Optional, Tuple

from tqdm.auto import tqdm
def product_call(
    fct: Callable,
    static: Optional[Dict[str, Any]] = None,
    multi: Optional[Dict[str, Any]] = None,
    progress=False,
) -> Tuple[List[Dict[str, Any]], List[Any]]:
    """Execute the same function multiple times, changing its parameters
    in each call.

    Kilian Lieret 2019
    https://gist.github.com/klieret/0f3b1b8a0b9b143702ae1aa662264029

    ``fct`` is called once for every element of the Cartesian product of
    the value lists in ``multi``; the ``static`` arguments are passed to
    every call.

    Example:
        product_call(fct, multi=dict(a=[1, 2, 3]), static=dict(b=4))

        will call

            fct(a=1, b=4)
            fct(a=2, b=4)
            fct(a=3, b=4)

    If the value of one of the static arguments is a string, it is formatted
    (via ``str.format_map``) with the multi arguments of the current call,
    e.g. ``static=dict(b="{a}.pdf")`` results in

            fct(a=1, b="1.pdf")
            fct(a=2, b="2.pdf")
            fct(a=3, b="3.pdf")

    Plain placeholders that do not name a multi argument (e.g. ``"{x}"``)
    are kept verbatim.

    Args:
        fct: The function to call.
        static: Dictionary of keyword arguments that are passed to the
            function and do not change between calls. If a key also
            appears in ``multi``, the static value takes precedence.
        multi: Dictionary mapping keyword argument names to iterables of
            values; every combination of these values is used exactly once.
        progress: Use tqdm to draw a progress bar (default False).

    Returns:
        Tuple ``(kwargs_list, results)``: the keyword-argument dictionaries
        used for each call and the corresponding return values, both in
        ``itertools.product`` order. If ``multi`` is empty, ``fct`` is
        called exactly once (with only the static arguments).
    """
    if static is None:
        static = {}
    if multi is None:
        multi = {}

    class FormatDefaultDict(dict):
        """Mapping for ``str.format_map`` that leaves unknown fields as-is."""

        def __missing__(self, key):
            return "{" + key + "}"

    keys = list(multi)
    # Value lists are taken in the same order as ``keys`` so each product
    # tuple can be zipped back onto its argument names.
    products = list(itertools.product(*(multi[key] for key in keys)))
    if progress:
        # Wrap the product list itself (not the ``progress`` flag) in tqdm.
        products = tqdm(products)

    results = []
    kwargs_list = []
    for setting in products:
        multi_kwargs = dict(zip(keys, setting))
        format_map = FormatDefaultDict(multi_kwargs)
        static_kwargs = {
            key: value.format_map(format_map) if isinstance(value, str) else value
            for key, value in static.items()
        }
        # Static arguments override multi arguments with the same name.
        kwargs = {**multi_kwargs, **static_kwargs}
        results.append(fct(**kwargs))
        kwargs_list.append(kwargs)
    return kwargs_list, results
def test_product_call():
    """Each (a, b) combination is called once, and ``c`` is built from it."""

    def add_first_two(a, b, c):
        # ``c`` is only accepted so that its formatted value shows up in
        # the recorded kwargs; it does not affect the result.
        return a + b

    a_values = [1, 2, 3]
    b_values = [4, 5, 6]

    kwargs, values = product_call(
        add_first_two,
        static=dict(c="{a}{b}"),
        multi=dict(a=a_values, b=b_values),
    )

    # itertools.product order: ``a`` is the outer loop, ``b`` the inner one.
    combos = [(a, b) for a in a_values for b in b_values]
    expected_kwargs = [{'a': a, 'b': b, 'c': str(a) + str(b)} for a, b in combos]
    expected_values = [a + b for a, b in combos]

    assert kwargs == expected_kwargs
    assert values == expected_values
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment