import tensorflow as tf


class BiMPM(object):
    '''
    Bilateral Multi-Perspective Matching Model
    https://arxiv.org/pdf/1702.03814.pdf
    '''

    def __init__(self, state_size, l_perspectives, embeddings_shape):
        self._state_size = state_size              # LSTM hidden state size
        self._l_perspectives = l_perspectives      # number of matching perspectives (l in the paper)
        self._embeddings_shape = embeddings_shape  # [vocab_size, embedding_dim]
        self._build_graph()

    def _build_graph(self):
        self.graph = tf.Graph()
        with self.graph.as_default():
            self._init_embeddings()
            self._init_placeholders()
            self._embeddings_lookup_layer()
            self._context_layer()
            self._matching_layer()
            self._aggregation_layer()
            self._prediction_layer()
            self._init_optimizer()

            # Metrics
            self.metrics = tf.summary.merge_all()
            self.saver = tf.train.Saver(max_to_keep=None)

def _init_embeddings(self): | |
self._embeddings = tf.Variable(tf.zeros(self._embeddings_shape), name='word_embeddings', trainable=False) | |
self.embeddings_placeholder = tf.placeholder(tf.float32, self._embeddings_shape) | |
self.embeddings_init_op = self._embeddings.assign(self.embeddings_placeholder) | |
    def _init_placeholders(self):
        self.seq1 = tf.placeholder(tf.int32, [None, None])  # [batch, seq1_max_len] word indices
        self.seq2 = tf.placeholder(tf.int32, [None, None])  # [batch, seq2_max_len] word indices
        self.seq1_len = tf.placeholder(tf.int32, [None])    # [batch] actual sequence lengths
        self.seq2_len = tf.placeholder(tf.int32, [None])    # [batch] actual sequence lengths
        self.targets = tf.placeholder(tf.int64, [None])     # [batch] binary labels
        self.is_training = tf.placeholder(tf.bool)
        self.dropout = tf.placeholder(tf.float32)            # dropout rate (not keep probability)

    def _embeddings_lookup_layer(self):
        '''
        Performs embedding lookup based on passed embedding indices.
        '''
        with tf.name_scope('embeddings_lookup_layer'):
            self._seq1_embedded = tf.nn.embedding_lookup(self._embeddings, self.seq1)
            self._seq2_embedded = tf.nn.embedding_lookup(self._embeddings, self.seq2)

    def _context_layer(self):
        '''
        BiRNN with shared weights for seq1 & seq2. Each hidden state is
        an embedding of the associated word plus its context.
        '''
        with tf.name_scope('context_layer'):
            encoder_fw_cell = tf.contrib.rnn.DropoutWrapper(
                tf.contrib.rnn.LSTMCell(self._state_size),
                output_keep_prob=1 - self.dropout,
                state_keep_prob=1 - self.dropout
            )
            encoder_bw_cell = tf.contrib.rnn.DropoutWrapper(
                tf.contrib.rnn.LSTMCell(self._state_size),
                output_keep_prob=1 - self.dropout,
                state_keep_prob=1 - self.dropout
            )

            with tf.variable_scope('encoder_birnn'):
                ((self._seq1_encoder_outputs_fw, self._seq1_encoder_outputs_bw),
                 (self._seq1_encoder_state_fw, self._seq1_encoder_state_bw)) = tf.nn.bidirectional_dynamic_rnn(
                    encoder_fw_cell,
                    encoder_bw_cell,
                    self._seq1_embedded,
                    sequence_length=self.seq1_len,
                    dtype=tf.float32
                )

            # Reuse the same cell weights for seq2 (shared encoder).
            with tf.variable_scope('encoder_birnn', reuse=True):
                ((self._seq2_encoder_outputs_fw, self._seq2_encoder_outputs_bw),
                 (self._seq2_encoder_state_fw, self._seq2_encoder_state_bw)) = tf.nn.bidirectional_dynamic_rnn(
                    encoder_fw_cell,
                    encoder_bw_cell,
                    self._seq2_embedded,
                    sequence_length=self.seq2_len,
                    dtype=tf.float32
                )

    def _matching_layer(self):
        '''
        Applies the four matching strategies to compare the hidden states of
        both context encoders in both directions: seq1->seq2 and seq2->seq1.
        '''
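        # Added note (from the referenced paper, not the original gist): every
        # strategy below builds on the multi-perspective cosine
        #     m_k = cosine(W_k o v1, W_k o v2),  k = 1..l_perspectives,
        # where o is the element-wise product and W_k is the k-th row of a
        # trainable [l_perspectives, state_size] matrix. _w_l2_normed() returns
        # the L2-normalised W_k o v vectors, so summing the product of two such
        # tensors over the last axis yields the cosine for every perspective.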
        def w(name=None):
            # One [l_perspectives, state_size] perspective-weight matrix.
            return tf.Variable(
                tf.random_uniform([self._l_perspectives, self._state_size], minval=0, maxval=0.2),
                name=name
            )

        with tf.name_scope('matching_layer'):
            W = {
                'W1': w('W1'), 'W2': w('W2'), 'W3': w('W3'), 'W4': w('W4'),
                'W5': w('W5'), 'W6': w('W6'), 'W7': w('W7'), 'W8': w('W8')
            }

            # seq1 -> seq2
            self._seq1_seq2_matched = tf.concat([
                self._full_match(self._seq1_encoder_outputs_fw, self._seq2_encoder_state_fw.h, W['W1']),
                self._full_match(self._seq1_encoder_outputs_bw, self._seq2_encoder_state_bw.h, W['W2']),
                self._maxpooling_match(self._seq1_encoder_outputs_fw, self._seq2_encoder_outputs_fw, W['W3']),
                self._maxpooling_match(self._seq1_encoder_outputs_bw, self._seq2_encoder_outputs_bw, W['W4']),
                self._attentive_match(self._seq1_encoder_outputs_fw, self._seq2_encoder_outputs_fw, W['W5']),
                self._attentive_match(self._seq1_encoder_outputs_bw, self._seq2_encoder_outputs_bw, W['W6']),
                self._max_attentive_match(self._seq1_encoder_outputs_fw, self._seq2_encoder_outputs_fw, W['W7']),
                self._max_attentive_match(self._seq1_encoder_outputs_bw, self._seq2_encoder_outputs_bw, W['W8'])
            ], 2, name='seq1_seq2_matched')

            # seq2 -> seq1
            self._seq2_seq1_matched = tf.concat([
                self._full_match(self._seq2_encoder_outputs_fw, self._seq1_encoder_state_fw.h, W['W1']),
                self._full_match(self._seq2_encoder_outputs_bw, self._seq1_encoder_state_bw.h, W['W2']),
                self._maxpooling_match(self._seq2_encoder_outputs_fw, self._seq1_encoder_outputs_fw, W['W3']),
                self._maxpooling_match(self._seq2_encoder_outputs_bw, self._seq1_encoder_outputs_bw, W['W4']),
                self._attentive_match(self._seq2_encoder_outputs_fw, self._seq1_encoder_outputs_fw, W['W5']),
                self._attentive_match(self._seq2_encoder_outputs_bw, self._seq1_encoder_outputs_bw, W['W6']),
                self._max_attentive_match(self._seq2_encoder_outputs_fw, self._seq1_encoder_outputs_fw, W['W7']),
                self._max_attentive_match(self._seq2_encoder_outputs_bw, self._seq1_encoder_outputs_bw, W['W8'])
            ], 2, name='seq2_seq1_matched')

            # Metrics
            tf.summary.histogram('seq1_seq2_matched', self._seq1_seq2_matched)
            tf.summary.histogram('seq2_seq1_matched', self._seq2_seq1_matched)

    def _aggregation_layer(self):
        '''
        BiRNN with shared weights for both sentences. Used to aggregate the
        concatenated matching strategies' output into a fixed-size vector
        (the final state).
        '''
        with tf.name_scope('aggregation_layer'):
            agg_fw_cell = tf.contrib.rnn.DropoutWrapper(
                tf.contrib.rnn.LSTMCell(self._state_size),
                output_keep_prob=1 - self.dropout,
                state_keep_prob=1 - self.dropout
            )
            agg_bw_cell = tf.contrib.rnn.DropoutWrapper(
                tf.contrib.rnn.LSTMCell(self._state_size),
                output_keep_prob=1 - self.dropout,
                state_keep_prob=1 - self.dropout
            )

            with tf.variable_scope('agg_birnn'):
                _, (seq1_agg_state_fw, seq1_agg_state_bw) = tf.nn.bidirectional_dynamic_rnn(
                    agg_fw_cell,
                    agg_bw_cell,
                    self._seq1_seq2_matched,
                    sequence_length=self.seq1_len,
                    dtype=tf.float32
                )

            with tf.variable_scope('agg_birnn', reuse=True):
                _, (seq2_agg_state_fw, seq2_agg_state_bw) = tf.nn.bidirectional_dynamic_rnn(
                    agg_fw_cell,
                    agg_bw_cell,
                    self._seq2_seq1_matched,
                    sequence_length=self.seq2_len,
                    dtype=tf.float32
                )

            self._aggregated_state = tf.concat([
                seq1_agg_state_fw.h,
                seq1_agg_state_bw.h,
                seq2_agg_state_fw.h,
                seq2_agg_state_bw.h
            ], 1, name='aggregated_state')

            # Metrics
            tf.summary.histogram('aggregated_state', self._aggregated_state)

    def _prediction_layer(self):
        '''
        Dense NN classifier on top of the aggregated state.
        '''
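        # Added note: the referenced paper applies a softmax in the output
        # layer; this implementation uses per-class sigmoids with log loss
        # (see _init_optimizer), so the two outputs are not constrained to
        # sum to 1.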
        with tf.name_scope('prediction_layer'):
            L1 = tf.layers.dropout(
                tf.layers.dense(self._aggregated_state, self._state_size, activation=tf.tanh),
                rate=self.dropout,
                training=self.is_training
            )
            self.y = tf.layers.dense(L1, 2, activation=tf.sigmoid, name='y')

            # Metrics
            tf.summary.histogram('y', self.y)

    def _init_optimizer(self):
        targets_onehot = tf.one_hot(self.targets, 2)
        self.loss = tf.losses.log_loss(targets_onehot, self.y)
        self.accuracy = tf.reduce_mean(tf.cast(tf.equal(self.targets, tf.argmax(self.y, 1)), tf.float32))
        self.train_op = tf.train.AdamOptimizer().minimize(self.loss)

        # Metrics
        tf.summary.scalar('loss', self.loss)
        tf.summary.scalar('accuracy', self.accuracy)

    # Matching functions & helpers
    def _cosine(self, v1, v2):
        # Pairwise cosine similarity between all time steps of v1 and v2:
        # [batch, T1, state] x [batch, T2, state] -> [batch, T1, T2].
        v1_norm = tf.nn.l2_normalize(v1, -1)
        v2_norm = tf.nn.l2_normalize(v2, -1)
        return tf.reduce_sum(tf.multiply(tf.expand_dims(v1_norm, -2), tf.expand_dims(v2_norm, -3)), -1)

    def _w_l2_normed(self, v, W):
        # Scales v by each perspective-weight row and L2-normalises the result:
        # [..., state] -> [..., l_perspectives, state].
        return tf.nn.l2_normalize(tf.multiply(tf.expand_dims(v, -2), W), -1)

    def _full_match(self, v1, v2, W):
        # Matches every time step of v1 against the final state v2:
        # -> [batch, T1, l_perspectives].
        with tf.name_scope('full_match'):
            return tf.reduce_sum(
                tf.multiply(self._w_l2_normed(v1, W), tf.expand_dims(self._w_l2_normed(v2, W), -3)),
                -1
            )

    def _maxpooling_match(self, v1, v2, W):
        # Matches every time step of v1 against every time step of v2 and
        # keeps the maximum per perspective: -> [batch, T1, l_perspectives].
        with tf.name_scope('maxpooling_match'):
            matched = tf.reduce_sum(tf.multiply(
                tf.expand_dims(self._w_l2_normed(v1, W), -3),
                tf.expand_dims(self._w_l2_normed(v2, W), -4)
            ), -1)
            return tf.reduce_max(matched, -2)

    def _attentive_match(self, v1, v2, W):
        # Matches every time step of v1 against a cosine-weighted mean of v2.
        with tf.name_scope('attentive_match'):
            a = self._cosine(v1, v2)  # [batch, T1, T2] attention weights
            h_mean = tf.matmul(a, v2) / (tf.reduce_sum(a, -1, keep_dims=True) + 1e-12)
            matched = tf.matmul(
                tf.expand_dims(self._w_l2_normed(v1, W), -2),
                tf.expand_dims(self._w_l2_normed(h_mean, W), -1)
            )
            return tf.squeeze(matched, [3, 4])

    def _max_attentive_match(self, v1, v2, W):
        # Matches every time step of v1 against its most similar (by cosine)
        # time step of v2.
        with tf.name_scope('max_attentive_match'):
            a = self._cosine(v1, v2)
            sim_indices = tf.argmax(a, axis=-1)  # [batch, T1] best v2 step per v1 step
            _h = tf.reduce_sum(
                tf.multiply(
                    tf.expand_dims(tf.one_hot(sim_indices, tf.shape(v2)[1]), -1),
                    tf.expand_dims(v2, -3)
                ), -2
            )
            matched = tf.matmul(
                tf.expand_dims(self._w_l2_normed(v1, W), -2),
                tf.expand_dims(self._w_l2_normed(_h, W), -1)
            )
            return tf.squeeze(matched, [3, 4])
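

if __name__ == '__main__':
    # Usage sketch (added; not part of the original gist). A single training
    # step on random toy data, assuming TF 1.x. All hyperparameters, shapes,
    # and the random "pretrained" embedding matrix below are illustrative
    # placeholders, not values from the original author.
    import numpy as np

    vocab_size, embedding_dim = 1000, 50
    embedding_matrix = np.random.uniform(-0.1, 0.1, [vocab_size, embedding_dim]).astype(np.float32)

    model = BiMPM(state_size=64, l_perspectives=10, embeddings_shape=[vocab_size, embedding_dim])
    with model.graph.as_default():
        init_op = tf.global_variables_initializer()

    with tf.Session(graph=model.graph) as sess:
        sess.run(init_op)
        # Load the (here: random) pretrained embeddings into the frozen variable.
        sess.run(model.embeddings_init_op, {model.embeddings_placeholder: embedding_matrix})

        # One padded batch of 8 sentence pairs with binary targets.
        feed = {
            model.seq1: np.random.randint(0, vocab_size, [8, 12]),
            model.seq1_len: [12] * 8,
            model.seq2: np.random.randint(0, vocab_size, [8, 15]),
            model.seq2_len: [15] * 8,
            model.targets: np.random.randint(0, 2, [8]),
            model.is_training: True,
            model.dropout: 0.1,
        }
        _, loss = sess.run([model.train_op, model.loss], feed)
        print('loss:', loss)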