Last active
November 11, 2020 18:58
-
-
Save lwiklendt/9c7099288f85b59edc903a5aed2d2d64 to your computer and use it in GitHub Desktop.
Extract parameter samples from PyStan's `vb` method so that the result resembles the output of `extract()` from the `sampling` method.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import OrderedDict

import numpy as np
def _parse_param_spec(param_spec):
    """Split a flattened name such as ``'beta[2,3]'`` into ``('beta', (2, 3))``.

    Indices are returned exactly as written (PyStan's vb output uses 1-based
    indices); a name without brackets yields an empty index tuple.
    """
    name, bracket, rest = param_spec.partition('[')
    if not bracket:
        return name, ()
    # drop the closing ']' and split the comma-separated indices
    return name, tuple(int(i) for i in rest[:-1].split(','))


def pystan_vb_extract(results):
    """Reshape PyStan ``vb`` output into per-parameter arrays, like ``extract()``.

    Args:
        results: the result mapping returned by PyStan's ``vb`` method. Only the
            keys ``'sampler_param_names'`` (flattened names such as ``'mu'`` or
            ``'theta[2,1]'``, with 1-based indices) and ``'sampler_params'``
            (one sequence of draws per name, in the same order) are read.

    Returns:
        An ``OrderedDict`` mapping each base parameter name, in first-seen
        order, to a float array of shape ``(n_draws,) + param_shape``. Elements
        that never appear in the flattened names are left as NaN. An empty
        result yields an empty ``OrderedDict``.
    """
    param_specs = results['sampler_param_names']
    samples = results['sampler_params']
    # guard against an empty result instead of raising IndexError on samples[0]
    n = len(samples[0]) if len(samples) > 0 else 0

    # parse every flattened name once; both passes below reuse the result
    parsed = [_parse_param_spec(spec) for spec in param_specs]

    # first pass: indices are 1-based, so the largest index seen along each
    # dimension is that dimension's size (no +1 needed)
    param_shapes = OrderedDict()
    for name, idxs in parsed:
        prev = param_shapes.get(name, idxs)
        param_shapes[name] = tuple(int(d) for d in np.maximum(idxs, prev))

    # create arrays with the draw axis first; NaN-filled so that any element
    # absent from the vb output stays NaN
    params = OrderedDict(
        (name, np.full((n,) + shape, np.nan))
        for name, shape in param_shapes.items()
    )

    # second pass: copy each flattened column into place (-1 converts the
    # 1-based vb indices to 0-based array indices)
    for (name, idxs), param_samples in zip(parsed, samples):
        params[name][(...,) + tuple(i - 1 for i in idxs)] = param_samples

    return params
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment