Created
December 3, 2016 03:35
-
-
Save llj098/0c27580364e27d53b1f386a248e36499 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import tensorflow as tf | |
from tensorflow.contrib.framework import add_arg_scope | |
from tensorflow.contrib.layers.python.layers import utils | |
slim = tf.contrib.slim | |
def squeeze(inputs, num_outputs, fire_id):
    """Apply the 1x1 "squeeze" convolution of a fire module.

    Args:
      inputs: Input feature map tensor.
      num_outputs: Number of 1x1 convolution filters.
      fire_id: Identifier appended (via ``str``) to the layer's scope name.

    Returns:
      The output tensor of ``slim.conv2d``.
    """
    layer_scope = "fire/squeeze/" + str(fire_id)
    return slim.conv2d(inputs, num_outputs, 1, stride=1, scope=layer_scope)
def expand(inputs, num_outputs, fire_id):
    """Apply the parallel 1x1 and 3x3 "expand" convolutions of a fire module.

    Both branches see the same input and use ``num_outputs`` filters each.
    Their outputs are concatenated along axis 3, so the result has
    ``2 * num_outputs`` channels along that axis.

    Args:
      inputs: Input feature map tensor (typically the squeeze output).
      num_outputs: Filters per branch.
      fire_id: Identifier appended (via ``str``) to both branches' scope names.

    Returns:
      The concatenated tensor.
    """
    suffix = str(fire_id)
    with tf.variable_scope('expand'):
        branch_1x1 = slim.conv2d(inputs, num_outputs, 1, stride=1,
                                 scope='fire/ex1x1/' + suffix)
        branch_3x3 = slim.conv2d(inputs, num_outputs, 3,
                                 scope='fire/ex3x3/' + suffix)
        # NOTE(review): this is the pre-1.0 ``tf.concat(concat_dim, values)``
        # argument order; TF >= 1.0 expects ``tf.concat(values, axis)``.
        # Confirm against the TensorFlow version actually in use.
        return tf.concat(3, [branch_1x1, branch_3x3])
def fire_module(x, s=16, e=64, fire_id=0):
    """Build one SqueezeNet fire module: squeeze (1x1) then expand (1x1 + 3x3).

    Args:
      x: Input feature map tensor.
      s: Number of 1x1 squeeze filters.
      e: Filters in each of the two expand branches; the branches are
        concatenated, so the module outputs ``2 * e`` channels.
      fire_id: Identifier used to build the layers' scope names.

    Returns:
      The output tensor of the expand stage.
    """
    squeezed = squeeze(x, s, fire_id)
    return expand(squeezed, e, fire_id)
def inference(images, _a, phase_train=True, weight_decay=0.8, reuse=None):
    """Build the SqueezeNet-style network and return its flattened output.

    Layer sequence: conv_1 -> maxpool1 -> fire2..fire4 -> maxpool3 ->
    fire5..fire8 -> maxpool8 -> fire9 -> dropout -> conv10 (1x1, 1792
    filters) -> 9x9 average pool -> flatten.

    Args:
      images: Image tensor. It is reshaped to ``[-1, 160, 160, 3]``, so its
        element count must be a multiple of 160*160*3.
      _a: Not used by this function; kept for interface compatibility.
        NOTE(review): presumably a keep-probability slot in the caller's
        API -- confirm. Dropout below uses a fixed value of 0.5.
      phase_train: Forwarded to ``slim.dropout`` as ``is_training``, so
        dropout is only applied in training mode. May be a Python bool or
        a boolean tensor.
      weight_decay: Not used. NOTE(review): no weight regularizer is
        attached to any layer; wiring one in would change the training
        loss, so it is intentionally left out here.
      reuse: Forwarded to the enclosing ``tf.variable_scope``.

    Returns:
      A tuple ``(net, {})``: the flattened activations and an empty dict.
    """
    # Shape comments below assume slim defaults: conv2d uses padding='SAME',
    # max_pool2d/avg_pool2d use padding='VALID' and stride=2.
    with tf.variable_scope('SqueezeNet', 'SqueezeNet', [images], reuse=reuse):
        x_image = tf.reshape(images, [-1, 160, 160, 3])
        net = slim.conv2d(x_image, 96, 7, scope="conv_1", stride=2)  # [?,80,80,96]
        net = slim.max_pool2d(net, 3, stride=2, scope="maxpool1")    # [?,39,39,96]
        net = fire_module(net, fire_id=2)                            # [?,39,39,128]
        net = fire_module(net, fire_id=3)                            # [?,39,39,128]
        net = fire_module(net, 32, 128, fire_id=4)                   # [?,39,39,256]
        net = slim.max_pool2d(net, 3, scope="maxpool3")              # [?,19,19,256]
        net = fire_module(net, 32, 128, fire_id=5)                   # [?,19,19,256]
        net = fire_module(net, 48, 192, fire_id=6)                   # [?,19,19,384]
        net = fire_module(net, 48, 192, fire_id=7)                   # [?,19,19,384]
        net = fire_module(net, 64, 256, fire_id=8)                   # [?,19,19,512]
        net = slim.max_pool2d(net, 2, stride=2, scope="maxpool8")    # [?,9,9,512]
        net = fire_module(net, 64, 256, fire_id=9)                   # [?,9,9,512]
        # Previously phase_train was ignored and slim.dropout defaulted to
        # is_training=True, so dropout also fired at evaluation time.
        net = slim.dropout(net, 0.5, is_training=phase_train)
        net = slim.conv2d(net, 1792, 1, scope="conv10", padding="VALID")  # [?,9,9,1792]
        net = slim.avg_pool2d(net, 9)                                # [?,1,1,1792]
        net = slim.flatten(net)                                      # [?,1792]
    return net, {}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment