Skip to content

Instantly share code, notes, and snippets.

@riveSunder
Created December 10, 2019 20:37
Show Gist options
  • Save riveSunder/05c7728c052f2ba28b494ad352f7e939 to your computer and use it in GitHub Desktop.
Save riveSunder/05c7728c052f2ba28b494ad352f7e939 to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
"""
a modification of craffel's `draw_neural_net` https://gist.github.com/craffel/2d727968c3aaebd10359
changes:
updated for python3 (xrange --> range)
function now takes layers themselves as an input instead of the layer dimensions
function draws connections according to weights: negative weights are colored red, positive weights blue (weights of exactly 1.0 are black),
and the width of each connection is determined by the weight magnitude.
"""
def draw_neural_net(ax, left, right, bottom, top, layers):
    """Draw a neural network cartoon on ``ax`` from its weight matrices.

    One column of circular nodes is drawn per layer, and one line per
    non-zero weight. Line width is ``3 * |weight|``; negative weights are
    drawn red, positive weights blue, and weights exactly equal to 1.0
    black. Zero weights are not drawn.

    :usage:
        >>> import numpy as np
        >>> fig = plt.figure(figsize=(12, 12))
        >>> weights = [np.random.randn(4, 7), np.random.randn(7, 2)]
        >>> draw_neural_net(fig.gca(), .1, .9, .1, .9, weights)

    :parameters:
        - ax : matplotlib.axes.Axes
            The axes on which to plot the cartoon (e.g. from ``plt.gca()``).
            Its axis decorations are switched off.
        - left : float
            x coordinate of the centers of the leftmost (input) nodes.
        - right : float
            x coordinate of the centers of the rightmost (output) nodes.
        - bottom : float
            Lower bound of the drawing area. The tallest layer's lowest node
            center sits half a node-spacing above this value.
        - top : float
            Upper bound of the drawing area. The tallest layer's highest node
            center sits half a node-spacing below this value.
        - layers : sequence of 2-D arrays
            Weight matrices ordered from input to output. ``layers[i]`` must
            have shape ``(size_i, size_{i+1})``. Entry ``[m, o]`` is the
            weight from node ``m`` of layer ``i`` to node ``o`` of layer
            ``i + 1``. Any object with a ``.shape`` and ``[m, o]`` indexing
            (e.g. a NumPy array) works.

    :raises:
        - ValueError : if ``layers`` is empty, contains a matrix that is not
          2-D, or consecutive matrices have incompatible shapes.
    """
    if len(layers) == 0:
        raise ValueError("layers must contain at least one weight matrix")
    for i, layer in enumerate(layers):
        if len(layer.shape) != 2:
            raise ValueError(
                f"layers[{i}] must be 2-D, got shape {tuple(layer.shape)}")
    for i, (prev, nxt) in enumerate(zip(layers[:-1], layers[1:])):
        # Columns of one matrix are the rows of the next; a mismatch would
        # either index out of range or silently drop connections.
        if prev.shape[1] != nxt.shape[0]:
            raise ValueError(
                f"layers[{i}] has {prev.shape[1]} outputs but "
                f"layers[{i + 1}] has {nxt.shape[0]} inputs")

    ax.axis('off')

    # Node count per layer: rows of every matrix, plus the columns of the
    # last matrix for the output layer.
    layer_sizes = [layer.shape[0] for layer in layers]
    layer_sizes.append(layers[-1].shape[1])

    # Vertical spacing is set by the tallest layer so every column fits.
    v_spacing = (top - bottom) / max(layer_sizes)
    h_spacing = (right - left) / (len(layer_sizes) - 1)
    center_y = (top + bottom) / 2.

    # x of each column, and y of the topmost node in each column. Every
    # column is centered vertically on center_y.
    xs = [left + n * h_spacing for n in range(len(layer_sizes))]
    tops = [center_y + v_spacing * (size - 1) / 2. for size in layer_sizes]

    # Nodes: drawn above the edges (high zorder) so lines end at the circles.
    for x, y_top, size in zip(xs, tops, layer_sizes):
        for m in range(size):
            ax.add_artist(plt.Circle((x, y_top - m * v_spacing), v_spacing / 4.,
                                     color='w', ec='k', zorder=40))

    # Edges: one line per non-zero weight between consecutive columns.
    for n, weights in enumerate(layers):
        for m in range(layer_sizes[n]):
            for o in range(layer_sizes[n + 1]):
                weight = weights[m, o]
                if not weight:
                    continue  # zero weights get no edge
                # NOTE(review): exactly-1.0 weights are singled out in black,
                # presumably to mark fixed/unit connections; confirm intent.
                if weight == 1.0:
                    color = 'k'
                elif weight < 0.0:
                    color = 'r'
                else:
                    color = 'b'
                ax.add_artist(plt.Line2D(
                    [xs[n], xs[n + 1]],
                    [tops[n] - m * v_spacing, tops[n + 1] - o * v_spacing],
                    lw=3 * abs(weight), color=color, alpha=0.25))
@riveSunder
Copy link
Author

Example Output
cartpole_champion_lineup

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment