Created
December 10, 2019 20:37
-
-
Save riveSunder/05c7728c052f2ba28b494ad352f7e939 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt
import numpy as np
"""
A modification of craffel's `draw_neural_net` (https://gist.github.com/craffel/2d727968c3aaebd10359).

Changes:
- Updated for Python 3 (xrange --> range).
- The function now takes the layers themselves (weight matrices) as input instead of the layer dimensions.
- Connections are drawn according to the weights: negative weights are colored red, positive weights
  blue (a weight of exactly 1.0 is drawn black), zero weights are not drawn, and the width of each
  connection is proportional to the weight magnitude.
"""
def draw_neural_net(ax, left, right, bottom, top, layers):
    '''
    Draw a neural network cartoon with matplotlib, driven by the weight matrices.

    Nodes are drawn as white circles with black outlines, one column per layer.
    An edge is drawn between every pair of nodes in adjacent layers whose
    connecting weight is non-zero. The edge's line width is ``3 * |weight|``
    and its alpha is 0.25. Negative weights are drawn red, positive weights
    blue, and weights exactly equal to 1.0 black.

    :usage:
        >>> fig = plt.figure(figsize=(12, 12))
        >>> weights = [np.random.randn(4, 7), np.random.randn(7, 2)]
        >>> draw_neural_net(fig.gca(), .1, .9, .1, .9, weights)

    :parameters:
        - ax : matplotlib.axes.Axes
            The axes on which to plot the cartoon (get e.g. by plt.gca())
        - left : float
            The center of the leftmost node(s) will be placed here
        - right : float
            The center of the rightmost node(s) will be placed here
        - bottom : float
            The center of the bottommost node(s) will be placed here
        - top : float
            The center of the topmost node(s) will be placed here
        - layers : sequence of 2-D array-likes
            Weight matrices ordered from input to output. ``layers[i]`` has
            shape ``(n_in, n_out)``, and ``layers[i][m, o]`` is the weight
            from node ``m`` of one layer to node ``o`` of the next. The
            number of drawn layers is ``len(layers) + 1``.

    :raises:
        - ValueError : if ``layers`` is empty, any matrix is not 2-D, or
          ``layers[i].shape[1] != layers[i + 1].shape[0]`` for some ``i``.
    '''
    # Accept nested lists as well as ndarrays; the edge loop indexes with [m, o].
    weights = [np.asarray(layer) for layer in layers]
    if not weights:
        raise ValueError("layers must contain at least one weight matrix")
    for i, w in enumerate(weights):
        if w.ndim != 2:
            raise ValueError(
                "layers[{}] must be 2-D, got shape {}".format(i, w.shape))
    for i, (w_a, w_b) in enumerate(zip(weights[:-1], weights[1:])):
        # A mismatch would make the edge loop index out of range or silently
        # skip connections to nodes that are never drawn.
        if w_a.shape[1] != w_b.shape[0]:
            raise ValueError(
                "layers[{}] has {} outputs but layers[{}] has {} inputs".format(
                    i, w_a.shape[1], i + 1, w_b.shape[0]))

    ax.axis('off')

    # One entry per drawn layer: the input size of each matrix plus the
    # output size of the last one.
    layer_sizes = [w.shape[0] for w in weights]
    layer_sizes.append(weights[-1].shape[1])
    if max(layer_sizes) == 0:
        return  # no nodes at all; also avoids dividing by zero below

    v_spacing = (top - bottom) / float(max(layer_sizes))
    h_spacing = (right - left) / float(len(layer_sizes) - 1)
    center_y = (top + bottom) / 2.

    # x of each layer's column, and y of its first node, chosen so that every
    # column is vertically centred on center_y.
    layer_xs = [n * h_spacing + left for n in range(len(layer_sizes))]
    layer_tops = [v_spacing * (size - 1) / 2. + center_y for size in layer_sizes]

    # Nodes
    for x, y_top, size in zip(layer_xs, layer_tops, layer_sizes):
        for m in range(size):
            circle = plt.Circle((x, y_top - m * v_spacing), v_spacing / 4.,
                                color='w', ec='k', zorder=40)
            ax.add_artist(circle)

    # Edges
    for n, w in enumerate(weights):
        x_a, x_b = layer_xs[n], layer_xs[n + 1]
        top_a, top_b = layer_tops[n], layer_tops[n + 1]
        n_in, n_out = w.shape
        for m in range(n_in):
            for o in range(n_out):
                weight_value = w[m, o]
                if not weight_value:
                    continue  # zero weight: no connection is drawn
                if weight_value == 1.0:
                    my_color = 'k'
                elif weight_value < 0.0:
                    my_color = 'r'
                else:
                    my_color = 'b'
                line = plt.Line2D([x_a, x_b],
                                  [top_a - m * v_spacing, top_b - o * v_spacing],
                                  lw=3 * np.abs(weight_value), color=my_color,
                                  alpha=0.25)
                ax.add_artist(line)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Example Output
