Created
November 13, 2020 15:35
-
-
Save guschmue/a8e6a52619e5f1734e89a98102d37146 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) Microsoft Corporation. All rights reserved. | |
# Licensed under the MIT license. | |
"""Test examples.""" | |
import os | |
import subprocess | |
import unittest | |
import tensorflow as tf | |
import numpy as np | |
import tensorflow_hub as hub | |
import tf2onnx | |
import timeit | |
import time | |
import zipfile | |
import keras2onnx | |
import onnxruntime as rt | |
from onnx import helper | |
from tensorflow.python.keras.saving import saving_utils as _saving_utils | |
from tensorflow.keras import layers, models | |
from tensorflow import keras | |
from common import check_opset_min_version, check_opset_max_version, check_tf_min_version | |
def onnx_session(model):
    """Create an onnxruntime inference session for ``model``.

    The CUDA execution provider is selected when onnxruntime reports a GPU
    build and ``CUDA_VISIBLE_DEVICES`` does not hide every GPU; otherwise the
    CPU execution provider is used.

    Args:
        model: Forwarded unchanged to ``rt.InferenceSession`` (the caller in
            this file passes the path of an ``.onnx`` file).

    Returns:
        A tuple ``(sess, inputs, outputs)`` holding the session and the lists
        of graph input names and graph output names.
    """
    providers = ['CPUExecutionProvider']
    if rt.get_device() == "GPU":
        gpus = os.environ.get("CUDA_VISIBLE_DEVICES")
        # Unset means no restriction. An empty value or "-1" is the usual way
        # to hide all GPUs; any other value (including a single id such as
        # "0") leaves at least one device visible, so CUDA can be used.
        if gpus is None or gpus.strip() not in ("", "-1"):
            providers = ['CUDAExecutionProvider']
    opt = rt.SessionOptions()
    opt.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
    sess = rt.InferenceSession(model, sess_options=opt, providers=providers)
    inputs = [input_meta.name for input_meta in sess.get_inputs()]
    outputs = [output_meta.name for output_meta in sess.get_outputs()]
    return sess, inputs, outputs
def to_onnx(model, output=None, use_tf2onnx=True, large_model=False):
    """Convert a Keras model to an ONNX model proto and optionally save it.

    Args:
        model: Keras model to convert.
        output: Destination path. When falsy, nothing is written to disk.
        use_tf2onnx: Convert with tf2onnx (opset 12) when true, otherwise
            with ``keras2onnx.convert_keras``.
        large_model: When true, the tf2onnx path keeps constant values in an
            ``ExternalTensorStorage`` and the result is saved with
            ``save_onnx_zip`` instead of ``save_protobuf``.

    Returns:
        A tuple ``(model_proto, inputs, outputs, external_tensor_storage)``
        where ``inputs``/``outputs`` are the graph input/output names and
        ``external_tensor_storage`` is ``None`` unless the tf2onnx path ran
        with ``large_model=True``.

    Raises:
        ValueError: If a zip archive is requested (``output`` set and
            ``large_model=True``) but no external tensor storage was produced,
            i.e. the keras2onnx path was used.
    """
    # Defined before branching so the keras2onnx path can still return it.
    external_tensor_storage = None
    if use_tf2onnx:
        const_node_values = None
        function = _saving_utils.trace_model_call(model)
        concrete_func = function.get_concrete_function()
        # Resource-typed tensors are filtered out of the graph input/output lists.
        input_names = [input_tensor.name for input_tensor in concrete_func.inputs
                       if input_tensor.dtype != tf.dtypes.resource]
        output_names = [output_tensor.name for output_tensor in concrete_func.outputs
                        if output_tensor.dtype != tf.dtypes.resource]
        frozen_graph = tf2onnx.tf_loader.from_function(
            concrete_func, input_names, output_names, large_model=large_model)
        if large_model:
            from tf2onnx.tf_utils import compress_graph_def
            const_node_values = compress_graph_def(frozen_graph)
            external_tensor_storage = tf2onnx.graph.ExternalTensorStorage()
        with tf.Graph().as_default() as tf_graph:
            tf.import_graph_def(frozen_graph, name='')
            g = tf2onnx.tfonnx.process_tf_graph(
                tf_graph, opset=12, input_names=input_names, output_names=output_names,
                const_node_values=const_node_values)
            onnx_graph = tf2onnx.optimizer.optimize_graph(g)
            model_proto = onnx_graph.make_model(
                "converted from tf2onnx", external_tensor_storage=external_tensor_storage)
    else:
        model_proto = keras2onnx.convert_keras(model, model.name)
    if output:
        if large_model:
            if external_tensor_storage is None:
                raise ValueError(
                    "large_model=True needs external tensor storage, which is only "
                    "produced when use_tf2onnx=True")
            tf2onnx.utils.save_onnx_zip(output, model_proto, external_tensor_storage)
        else:
            tf2onnx.utils.save_protobuf(output, model_proto)
    inputs = [n.name for n in model_proto.graph.input]
    outputs = [n.name for n in model_proto.graph.output]
    return model_proto, inputs, outputs, external_tensor_storage
def bert_from_hub(url, output_name):
    """Convert a TF-Hub BERT layer to ONNX and compare it against TensorFlow.

    Builds a Keras model around ``hub.KerasLayer(url)`` with three int32
    inputs of length 512, runs it on all-ones inputs, converts it with
    ``to_onnx`` and runs the converted model through onnxruntime. Average
    latency per run (milliseconds, 20 runs each) is printed for both
    runtimes and the outputs are compared.

    Args:
        url: TF-Hub handle of the BERT model.
        output_name: Path of the converted model. A name ending in ``.zip``
            selects the large-model path: the archive is unpacked next to it
            into ``<output_name>.unpacked`` before loading.

    Raises:
        AssertionError: If the TensorFlow and onnxruntime outputs differ by
            more than the tolerance given to ``np.testing.assert_allclose``.
    """
    max_seq_length = 512
    input_word_ids = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name="input_word_ids")
    input_mask = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name="input_mask")
    segment_ids = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name="segment_ids")
    bert_layer = hub.KerasLayer(url, trainable=False)
    bert_inputs = [input_word_ids, input_mask, segment_ids]
    pooled_output, sequence_output = bert_layer(bert_inputs)
    model = tf.keras.models.Model(inputs=bert_inputs, outputs=[
        pooled_output, sequence_output])
    model.compile(loss='binary_crossentropy',
                  optimizer='adam', metrics=['accuracy'])
    print("inputs", model.inputs)
    print("outputs", model.outputs)

    in_words = np.ones([1, max_seq_length], dtype=np.int32)
    in_mask = np.ones([1, max_seq_length], dtype=np.int32)
    in_segment = np.ones([1, max_seq_length], dtype=np.int32)
    tf_inputs = [in_words, in_mask, in_segment]
    # This first call yields the reference outputs and is not timed.
    tf_ret = model.predict(tf_inputs)

    runs = 20
    print("tf perftest ...")
    start = time.perf_counter()
    for _ in range(runs):
        model.predict(tf_inputs)
    tf_res = 1000 * (time.perf_counter() - start) / runs

    print("convert ...")
    large_model = output_name.endswith(".zip")
    # Only the file written to output_name is needed; return values are unused.
    to_onnx(model, output_name, use_tf2onnx=True, large_model=large_model)
    tf.compat.v1.reset_default_graph()

    feeds = {"input_word_ids:0": in_words,
             "input_mask:0": in_mask,
             "segment_ids:0": in_segment}
    if large_model:
        # The zip holds the model proto plus its external tensors; unpack it
        # so onnxruntime can load the proto from disk.
        save_path = output_name + ".unpacked"
        with zipfile.ZipFile(output_name, 'r') as z:
            z.extractall(save_path)
        model_path = os.path.join(save_path, "__MODEL_PROTO.onnx")
    else:
        # to_onnx wrote a plain protobuf, so there is nothing to unpack.
        model_path = output_name

    sess, _, outputs = onnx_session(model_path)
    onnx_ret = sess.run(outputs, feeds)
    print("ort perftest ...")
    start = time.perf_counter()
    for _ in range(runs):
        sess.run(outputs, feeds)
    ort_res = 1000 * (time.perf_counter() - start) / runs

    print(f"{url} TF={tf_res} ORT={ort_res}")
    # NOTE(review): rtol=1 allows a 100% relative error, so this check is
    # very loose -- confirm whether a tighter tolerance is intended.
    for i1, i2 in zip(tf_ret, onnx_ret):
        np.testing.assert_allclose(i1, i2, rtol=1)
class TestKerasApps(unittest.TestCase):
    """Run the TF-Hub BERT conversion check for several published models.

    Only the first case is enabled; the others are skipped with reason "later".
    """

    def setUp(self):
        """Reset the TF v1 default graph before every test."""
        tf.compat.v1.reset_default_graph()

    def test_tfkeras_bert0(self):
        """Cased BERT, L-12 H-768 A-12, hub version 2."""
        hub_url = "https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/2"
        archive = "/tmp/bert_en_cased_L-12_H-768_A-12.zip"
        bert_from_hub(hub_url, archive)

    @unittest.skip("later")
    def test_tfkeras_bert1(self):
        """Cased BERT, L-24 H-1024 A-16, hub version 2."""
        hub_url = "https://tfhub.dev/tensorflow/bert_en_cased_L-24_H-1024_A-16/2"
        archive = "/tmp/bert_en_cased_L-24_H-1024_A-16.zip"
        bert_from_hub(hub_url, archive)

    @unittest.skip("later")
    def test_tfkeras_bert2(self):
        """Uncased BERT, L-12 H-768 A-12, hub version 2."""
        hub_url = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2"
        archive = "/tmp/bert_en_uncased_L-12_H-768_A-12.zip"
        bert_from_hub(hub_url, archive)

    @unittest.skip("later")
    def test_tfkeras_bert3(self):
        """Uncased BERT, L-24 H-1024 A-16, hub version 2."""
        hub_url = "https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/2"
        archive = "/tmp/bert_en_uncased_L-24_H-1024_A-16.zip"
        bert_from_hub(hub_url, archive)

    @unittest.skip("later")
    def test_tfkeras_bert4(self):
        """Whole-word-masking cased BERT, L-24 H-1024 A-16, hub version 2."""
        hub_url = "https://tfhub.dev/tensorflow/bert_en_wwm_cased_L-24_H-1024_A-16/2"
        archive = "/tmp/bert_en_wwm_cased_L-24_H-1024_A-16.zip"
        bert_from_hub(hub_url, archive)

    @unittest.skip("later")
    def test_tfkeras_bert5(self):
        """Whole-word-masking uncased BERT, L-24 H-1024 A-16, hub version 2."""
        hub_url = "https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/2"
        archive = "/tmp/bert_en_wwm_uncased_L-24_H-1024_A-16.zip"
        bert_from_hub(hub_url, archive)

    @unittest.skip("later")
    def test_tfkeras_bert6(self):
        """Cased BERT, L-12 H-768 A-12, hub version 3."""
        hub_url = "https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/3"
        archive = "/tmp/bert_en_cased_L-12_H-768_A-12_3.zip"
        bert_from_hub(hub_url, archive)
# Run the unittest runner only when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment