Last active: September 11, 2017 06:20
Save eldar/0ecc058670be340b92e5a1044dc8a089 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime as dt
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import resnet_v1
import threading
from PoseDataset import PoseDataset
from TrainParams import TrainParams

# Training data source: the PoseDataset class imported above (the name
# `MyDataset` is not defined anywhere in this script).
dataset = PoseDataset()
# NOTE(review): train_param is not referenced by the visible code.
train_param = TrainParams()

# Maximum number of (image, labels) pairs buffered between the loader
# thread and the training loop.
QUEUE_SIZE = 50
# Number of label / output channels.
num_classes = 14

# Feed points for a single example (batch dimension fixed at 1) of
# arbitrary height and width; filled by the loader thread.
inputs = tf.placeholder(tf.float32, shape=[1, None, None, 3])
data_labels = tf.placeholder(tf.float32, shape=[1, None, None, num_classes])

# Input pipeline: the loader thread runs enqueue_op, and the network is
# built on the tensors produced by dequeue().
q = tf.FIFOQueue(QUEUE_SIZE, [tf.float32, tf.float32])
enqueue_op = q.enqueue([inputs, data_labels])
inputs_batch, targets_batch = q.dequeue()
# The queue was created without shapes, so the dequeued tensors have no
# static shape; restore rank and channel count for the layers below.
inputs_batch.set_shape([1, None, None, 3])
targets_batch.set_shape([1, None, None, num_classes])
def load_and_enqueue(sess, enqueue_op, coord, dataset):
    """Push batches from ``dataset`` into the input queue until told to stop.

    Intended to run on a background thread. Each iteration takes one batch
    from ``dataset.next_batch()`` and runs ``enqueue_op``, feeding the
    module-level ``inputs`` and ``data_labels`` placeholders.

    Args:
        sess: session used to run ``enqueue_op``.
        enqueue_op: op that enqueues one (inputs, data_labels) pair.
        coord: coordinator; the loop ends once ``coord.should_stop()``.
        dataset: object whose ``next_batch()`` returns a mapping with
            ``'inputs'`` and ``'data_labels'`` entries (presumably arrays
            matching the placeholder shapes -- confirm against PoseDataset).
    """
    while not coord.should_stop():
        batch = dataset.next_batch()
        try:
            sess.run(enqueue_op, feed_dict={inputs: batch['inputs'],
                                            data_labels: batch['data_labels']})
        except tf.errors.CancelledError:
            # The queue was closed during shutdown. An enqueue that was
            # blocked on a full queue, or any later enqueue, is cancelled.
            # Leave quietly rather than dying with a traceback.
            return
# Build the network on the dequeued batch.
# NOTE(review): the positional False is presumably `is_training` in the
# TF-Slim version this targets (batch norm in inference mode); newer slim
# releases take `weight_decay` as the first argument -- confirm.
with slim.arg_scope(resnet_v1.resnet_arg_scope(False)):
    # Per-channel RGB mean subtracted from the input; presumably the mean
    # the resnet_v1_101.ckpt weights were trained with -- confirm.
    mean = tf.constant([123.68, 116.779, 103.939],
                       dtype=tf.float32, shape=[1, 1, 1, 3], name='img_mean')
    im_centered = inputs_batch - mean
    # global_pool=False keeps a spatial feature map for per-pixel output.
    net, end_points = resnet_v1.resnet_v1_101(im_centered,
                                              global_pool=False,
                                              output_stride=16)
    # Stride-2 transposed convolution producing one score map per class.
    # activation_fn=None: these are logits for the sigmoid loss below, and
    # slim's default ReLU would clamp them to >= 0, so no prediction could
    # fall below probability 0.5.
    pred_upconv = slim.conv2d_transpose(net, num_classes,
                                        kernel_size=[3, 3],
                                        stride=2,
                                        padding='SAME',
                                        activation_fn=None)

# NOTE(review): (logits, labels) argument order matches the old slim.losses
# API; tf.losses.sigmoid_cross_entropy takes labels first -- check on upgrade.
loss = slim.losses.sigmoid_cross_entropy(pred_upconv, targets_batch)
model_path = 'resnet_v1_101.ckpt'

# Select the pretrained variables before any optimizer exists. Slot
# variables (e.g. Momentum/Adam) would also live under "resnet_v1/" and
# would otherwise be looked up in the checkpoint.
variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
restorer = tf.train.Saver(variables_to_restore)

# Finish building the graph before initializing variables and starting the
# loader thread. Any optimizer slots then get initialized, and no ops are
# added while another thread is running the graph.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=.001)
train_op = optimizer.minimize(loss)
# Created up front so shutdown does not add ops to the graph.
close_queue_op = q.close(cancel_pending_enqueues=True)

sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(tf.initialize_local_variables())
# Restore variables from disk. Only the "resnet_v1" scope is loaded; other
# variables (such as the transposed convolution) keep their initial values.
restorer.restore(sess, model_path)

coord = tf.train.Coordinator()
t = threading.Thread(target=load_and_enqueue,
                     args=(sess, enqueue_op, coord, dataset))
t.start()

try:
    for _ in range(10000):
        sess.run(train_op)
finally:
    # Shut the loader down. It may be blocked inside an enqueue on a full
    # queue, where should_stop() is never re-checked. Closing the queue with
    # cancel_pending_enqueues unblocks it, so the join can complete and the
    # process can exit.
    coord.request_stop()
    sess.run(close_queue_op)
    coord.join([t])
    sess.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Relevant Stack Overflow question: https://stackoverflow.com/questions/39774449/fully-convolutional-resnets-using-tf-slim-run-very-slow/40126349#40126349