import tensorflow as tf


class BiMPM(object):
    '''
    Bilateral Multi-Perspective Matching Model
    https://arxiv.org/pdf/1702.03814.pdf
    '''

    def __init__(self, state_size, l_perspectives, embeddings_shape):
        self._state_size = state_size              # LSTM hidden state size
        self._l_perspectives = l_perspectives      # number of matching perspectives (l)
        self._embeddings_shape = embeddings_shape  # [vocab_size, embedding_dim]
        self._build_graph()
    def _build_graph(self):
        self.graph = tf.Graph()
        with self.graph.as_default():
            self._init_embeddings()
            self._init_placeholders()
            self._embeddings_lookup_layer()
            self._context_layer()
            self._matching_layer()
            self._aggregation_layer()
            self._prediction_layer()
            self._init_optimizer()
            # Metrics
            self.metrics = tf.summary.merge_all()
            self.saver = tf.train.Saver(max_to_keep=None)
    def _init_embeddings(self):
        # Word embeddings are loaded via embeddings_init_op and kept frozen (trainable=False).
        self._embeddings = tf.Variable(tf.zeros(self._embeddings_shape), name='word_embeddings', trainable=False)
        self.embeddings_placeholder = tf.placeholder(tf.float32, self._embeddings_shape)
        self.embeddings_init_op = self._embeddings.assign(self.embeddings_placeholder)

    def _init_placeholders(self):
        self.seq1 = tf.placeholder(tf.int32, [None, None])  # [batch, time] word ids
        self.seq2 = tf.placeholder(tf.int32, [None, None])  # [batch, time] word ids
        self.seq1_len = tf.placeholder(tf.int32, [None])    # actual lengths of seq1
        self.seq2_len = tf.placeholder(tf.int32, [None])    # actual lengths of seq2
        self.targets = tf.placeholder(tf.int64, [None])     # binary class labels
        self.is_training = tf.placeholder(tf.bool)
        self.dropout = tf.placeholder(tf.float32)            # dropout rate (not keep prob)
    def _embeddings_lookup_layer(self):
        '''
        Looks up word embeddings for the passed word indices.
        '''
        with tf.name_scope('embeddings_lookup_layer'):
            self._seq1_embedded = tf.nn.embedding_lookup(self._embeddings, self.seq1)
            self._seq2_embedded = tf.nn.embedding_lookup(self._embeddings, self.seq2)
    def _context_layer(self):
        '''
        BiRNN with shared weights for seq1 & seq2. Every hidden state is
        an embedding of the associated word plus its context.
        '''
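        # Each direction of the shared BiRNN produces per-time-step outputs of shape
        # [batch, time, state_size] plus a final LSTM state; the matching layer uses the
        # outputs, and the full-match strategy additionally uses the final states.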
        with tf.name_scope('context_layer'):
            encoder_fw_cell = tf.contrib.rnn.DropoutWrapper(
                tf.contrib.rnn.LSTMCell(self._state_size),
                output_keep_prob=1 - self.dropout,
                state_keep_prob=1 - self.dropout
            )
            encoder_bw_cell = tf.contrib.rnn.DropoutWrapper(
                tf.contrib.rnn.LSTMCell(self._state_size),
                output_keep_prob=1 - self.dropout,
                state_keep_prob=1 - self.dropout
            )
            with tf.variable_scope('encoder_birnn'):
                ((self._seq1_encoder_outputs_fw, self._seq1_encoder_outputs_bw),
                 (self._seq1_encoder_state_fw, self._seq1_encoder_state_bw)) = tf.nn.bidirectional_dynamic_rnn(
                    encoder_fw_cell,
                    encoder_bw_cell,
                    self._seq1_embedded,
                    sequence_length=self.seq1_len,
                    dtype=tf.float32
                )
            with tf.variable_scope('encoder_birnn', reuse=True):
                ((self._seq2_encoder_outputs_fw, self._seq2_encoder_outputs_bw),
                 (self._seq2_encoder_state_fw, self._seq2_encoder_state_bw)) = tf.nn.bidirectional_dynamic_rnn(
                    encoder_fw_cell,
                    encoder_bw_cell,
                    self._seq2_embedded,
                    sequence_length=self.seq2_len,
                    dtype=tf.float32
                )
    def _matching_layer(self):
        '''
        Apply different matching strategies to compare the hidden states
        of both context encoders in both directions: seq1 -> seq2 and seq2 -> seq1.
        '''
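        # Every strategy below uses the paper's multi-perspective cosine function
        # f_m(v1, v2; W): for perspective k, m_k = cosine(W_k * v1, W_k * v2), where the
        # product is element-wise and W has shape [l_perspectives, state_size].
        # Each strategy returns a [batch, time, l_perspectives] tensor, so the
        # concatenation of the eight strategies is [batch, time, 8 * l_perspectives].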
        def w(name=None):
            return tf.Variable(
                tf.random_uniform([self._l_perspectives, self._state_size], minval=0, maxval=0.2),
                name=name
            )

        with tf.name_scope('matching_layer'):
            W = {'W1': w('W1'), 'W2': w('W2'), 'W3': w('W3'), 'W4': w('W4'),
                 'W5': w('W5'), 'W6': w('W6'), 'W7': w('W7'), 'W8': w('W8')}
            # seq1 -> seq2
            self._seq1_seq2_matched = tf.concat([
                self._full_match(self._seq1_encoder_outputs_fw, self._seq2_encoder_state_fw.h, W['W1']),
                self._full_match(self._seq1_encoder_outputs_bw, self._seq2_encoder_state_bw.h, W['W2']),
                self._maxpooling_match(self._seq1_encoder_outputs_fw, self._seq2_encoder_outputs_fw, W['W3']),
                self._maxpooling_match(self._seq1_encoder_outputs_bw, self._seq2_encoder_outputs_bw, W['W4']),
                self._attentive_match(self._seq1_encoder_outputs_fw, self._seq2_encoder_outputs_fw, W['W5']),
                self._attentive_match(self._seq1_encoder_outputs_bw, self._seq2_encoder_outputs_bw, W['W6']),
                self._max_attentive_match(self._seq1_encoder_outputs_fw, self._seq2_encoder_outputs_fw, W['W7']),
                self._max_attentive_match(self._seq1_encoder_outputs_bw, self._seq2_encoder_outputs_bw, W['W8'])
            ], 2, name='seq1_seq2_matched')
            # seq2 -> seq1
            self._seq2_seq1_matched = tf.concat([
                self._full_match(self._seq2_encoder_outputs_fw, self._seq1_encoder_state_fw.h, W['W1']),
                self._full_match(self._seq2_encoder_outputs_bw, self._seq1_encoder_state_bw.h, W['W2']),
                self._maxpooling_match(self._seq2_encoder_outputs_fw, self._seq1_encoder_outputs_fw, W['W3']),
                self._maxpooling_match(self._seq2_encoder_outputs_bw, self._seq1_encoder_outputs_bw, W['W4']),
                self._attentive_match(self._seq2_encoder_outputs_fw, self._seq1_encoder_outputs_fw, W['W5']),
                self._attentive_match(self._seq2_encoder_outputs_bw, self._seq1_encoder_outputs_bw, W['W6']),
                self._max_attentive_match(self._seq2_encoder_outputs_fw, self._seq1_encoder_outputs_fw, W['W7']),
                self._max_attentive_match(self._seq2_encoder_outputs_bw, self._seq1_encoder_outputs_bw, W['W8'])
            ], 2, name='seq2_seq1_matched')
            # Metrics
            tf.summary.histogram('seq1_seq2_matched', self._seq1_seq2_matched)
            tf.summary.histogram('seq2_seq1_matched', self._seq2_seq1_matched)
    def _aggregation_layer(self):
        '''
        BiRNN with shared weights for both sentences. Used to
        aggregate the concatenated matching strategies' output into
        a fixed-size vector (final state).
        '''
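        # The final forward/backward states from both aggregation passes are concatenated
        # into a single [batch, 4 * state_size] vector consumed by the prediction layer.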
        with tf.name_scope('aggregation_layer'):
            agg_fw_cell = tf.contrib.rnn.DropoutWrapper(
                tf.contrib.rnn.LSTMCell(self._state_size),
                output_keep_prob=1 - self.dropout,
                state_keep_prob=1 - self.dropout
            )
            agg_bw_cell = tf.contrib.rnn.DropoutWrapper(
                tf.contrib.rnn.LSTMCell(self._state_size),
                output_keep_prob=1 - self.dropout,
                state_keep_prob=1 - self.dropout
            )
            with tf.variable_scope('agg_birnn'):
                _, (seq1_agg_state_fw, seq1_agg_state_bw) = tf.nn.bidirectional_dynamic_rnn(
                    agg_fw_cell,
                    agg_bw_cell,
                    self._seq1_seq2_matched,
                    sequence_length=self.seq1_len,
                    dtype=tf.float32
                )
            with tf.variable_scope('agg_birnn', reuse=True):
                _, (seq2_agg_state_fw, seq2_agg_state_bw) = tf.nn.bidirectional_dynamic_rnn(
                    agg_fw_cell,
                    agg_bw_cell,
                    self._seq2_seq1_matched,
                    sequence_length=self.seq2_len,
                    dtype=tf.float32
                )
            self._aggregated_state = tf.concat([
                seq1_agg_state_fw.h,
                seq1_agg_state_bw.h,
                seq2_agg_state_fw.h,
                seq2_agg_state_bw.h
            ], 1, name='aggregated_state')
            # Metrics
            tf.summary.histogram('aggregated_state', self._aggregated_state)
    def _prediction_layer(self):
        '''
        Dense NN classifier.
        '''
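        # Two-layer feed-forward head: a tanh hidden layer with dropout, then a
        # two-unit sigmoid output that _init_optimizer scores against one-hot targets
        # with log loss.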
        with tf.name_scope('prediction_layer'):
            L1 = tf.layers.dropout(
                tf.layers.dense(self._aggregated_state, self._state_size, activation=tf.tanh),
                rate=self.dropout,
                training=self.is_training
            )
            self.y = tf.layers.dense(L1, 2, activation=tf.sigmoid, name='y')
            # Metrics
            tf.summary.histogram('y', self.y)
    def _init_optimizer(self):
        targets_onehot = tf.one_hot(self.targets, 2)
        self.loss = tf.losses.log_loss(targets_onehot, self.y)
        self.accuracy = tf.reduce_mean(tf.cast(tf.equal(self.targets, tf.argmax(self.y, 1)), tf.float32))
        self.train_op = tf.train.AdamOptimizer().minimize(self.loss)
        # Metrics
        tf.summary.scalar('loss', self.loss)
        tf.summary.scalar('accuracy', self.accuracy)
    # Matching functions & helpers

    def _cosine(self, v1, v2):
        # Pairwise cosine similarity between all time steps of v1 and v2: [batch, T1, T2].
        v1_norm = tf.nn.l2_normalize(v1, -1)
        v2_norm = tf.nn.l2_normalize(v2, -1)
        return tf.reduce_sum(tf.multiply(tf.expand_dims(v1_norm, -2), tf.expand_dims(v2_norm, -3)), -1)

    def _w_l2_normed(self, v, W):
        # Scale v by each of the l perspective weight rows and l2-normalize,
        # adding a perspective axis.
        return tf.nn.l2_normalize(tf.multiply(tf.expand_dims(v, -2), W), -1)

    def _full_match(self, v1, v2, W):
        # Match every time step of v1 against the other sentence's final encoder
        # state v2: [batch, T1, l_perspectives].
        with tf.name_scope('full_match'):
            return tf.reduce_sum(
                tf.multiply(self._w_l2_normed(v1, W), tf.expand_dims(self._w_l2_normed(v2, W), -3)),
                -1
            )

    def _maxpooling_match(self, v1, v2, W):
        # Match every time step of v1 against every time step of v2 and keep the
        # maximum over v2's time steps: [batch, T1, l_perspectives].
        with tf.name_scope('maxpooling_match'):
            matched = tf.reduce_sum(tf.multiply(
                tf.expand_dims(self._w_l2_normed(v1, W), -3),
                tf.expand_dims(self._w_l2_normed(v2, W), -4)
            ), -1)
            return tf.reduce_max(matched, -2)

    def _attentive_match(self, v1, v2, W):
        # Match every time step of v1 against a cosine-attention-weighted mean of v2:
        # [batch, T1, l_perspectives].
        with tf.name_scope('attentive_match'):
            a = self._cosine(v1, v2)
            h_mean = tf.matmul(a, v2) / (tf.reduce_sum(a, -1, keep_dims=True) + 1e-12)
            matched = tf.matmul(
                tf.expand_dims(self._w_l2_normed(v1, W), -2),
                tf.expand_dims(self._w_l2_normed(h_mean, W), -1)
            )
            return tf.squeeze(matched, [3, 4])

    def _max_attentive_match(self, v1, v2, W):
        # Match every time step of v1 against its most cosine-similar time step of v2:
        # [batch, T1, l_perspectives].
        with tf.name_scope('max_attentive_match'):
            a = self._cosine(v1, v2)
            sim_indices = tf.argmax(a, axis=-1)
            _h = tf.reduce_sum(
                tf.multiply(
                    tf.expand_dims(tf.one_hot(sim_indices, tf.shape(v2)[1]), -1),
                    tf.expand_dims(v2, -3)
                ), -2
            )
            matched = tf.matmul(
                tf.expand_dims(self._w_l2_normed(v1, W), -2),
                tf.expand_dims(self._w_l2_normed(_h, W), -1)
            )
            return tf.squeeze(matched, [3, 4])
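

if __name__ == '__main__':
    # Hypothetical usage sketch (not part of the original gist): random embeddings and
    # random token-id batches stand in for a real data pipeline, and the hyperparameters
    # below (state_size, l_perspectives, dropout) are illustrative rather than the
    # paper's settings. Requires TensorFlow 1.x.
    import numpy as np

    vocab_size, embedding_dim, batch_size, max_len = 1000, 50, 4, 7
    embedding_matrix = np.random.randn(vocab_size, embedding_dim).astype(np.float32)

    model = BiMPM(state_size=100, l_perspectives=20,
                  embeddings_shape=[vocab_size, embedding_dim])

    feed = {
        model.seq1: np.random.randint(0, vocab_size, size=(batch_size, max_len)),
        model.seq2: np.random.randint(0, vocab_size, size=(batch_size, max_len)),
        model.seq1_len: np.full(batch_size, max_len, dtype=np.int32),
        model.seq2_len: np.full(batch_size, max_len, dtype=np.int32),
        model.targets: np.random.randint(0, 2, size=batch_size),
        model.is_training: True,
        model.dropout: 0.1,
    }

    with model.graph.as_default(), tf.Session(graph=model.graph) as sess:
        sess.run(tf.global_variables_initializer())
        # Load the (non-trainable) word embeddings.
        sess.run(model.embeddings_init_op, {model.embeddings_placeholder: embedding_matrix})
        # Single training step on the random batch.
        _, loss = sess.run([model.train_op, model.loss], feed)
        print('loss:', loss)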