Simple method for ensuring that the gradient flows through all your variables in TensorFlow.
import tensorflow as tf  # 1.3.0

# model + input
w = tf.Variable(1, name="w", dtype=tf.float32)            # parameter to optimize for
x = tf.placeholder(shape=(), dtype=tf.float32, name="x")  # input

# graph operations
diff_op = tf.multiply(w, x)                                    # an operation that is differentiable
non_diff_op = tf.cast(tf.equal(diff_op, 0), dtype=tf.float32)  # an operation that is non-differentiable

# rudimentary check for gradient flow: for each operation, verify that some
# gradient can flow to at least one of your trainable parameters
# (tf.gradients returns None for parameters with no gradient path to the op)
def ensure_gradient_flow(operations):
    tf_params = tf.trainable_variables()
    no_flows = 0
    for op in operations:
        gradients = tf.gradients(op, tf_params)
        at_least_one_gradient_flows = False
        for gradient, param in zip(gradients, tf_params):
            if gradient is not None:  # None means no gradient flows to this parameter
                at_least_one_gradient_flows = True
                break
        if not at_least_one_gradient_flows:
            print("No gradient flow for operation '%s'" % op.name)
            no_flows += 1
    if no_flows == 0:
        print("Operations [%s] have at least 1 gradient for at least 1 parameter" % [o.name for o in operations])
    print(tf.global_variables())  # list all variables in the graph for reference

ensure_gradient_flow([diff_op])      # gradient flows: w is connected to diff_op
ensure_gradient_flow([non_diff_op])  # no gradient flow: tf.equal has no gradient
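
A minimal usage sketch, not part of the original gist: because tf.gradients only inspects the graph, the check can run at graph-construction time, before any session exists. The toy loss and learning rate below are assumed values for illustration, reusing w, x, and ensure_gradient_flow from above.

# Hypothetical usage sketch (assumes the definitions above): check a toy loss
# before building the train op, so broken gradient flow is caught early --
# no tf.Session is needed, since tf.gradients works on the static graph.
loss = tf.square(tf.multiply(w, x) - 1.0)  # assumed toy squared-error loss
ensure_gradient_flow([loss])               # should report gradient flow to w
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)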