# Gist kingspp/e87d35a8a9c70d699a2b3bb941dfb4bd (last active April 8, 2020 15:46).
# Save it to your computer and use it in GitHub Desktop.
# Showcases a custom metric (confusion matrix) implementation in TensorFlow.
#
# Note: this file contains hidden or bidirectional Unicode text that may be interpreted or
# compiled differently than what appears below. To review, open the file in an editor that
# reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
import io
import itertools
import re
import textwrap
import typing

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
# Inserts a space at camelCase / acronym boundaries, e.g. "EventOccured" -> "Event Occured"
# and "HTTPServer" -> "HTTP Server". Compiled once instead of on every call.
_LABEL_BOUNDARY = re.compile(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))')


def _humanize_labels(label_mappings: typing.Dict, width: int = 20) -> typing.List[str]:
    """Return tick-label strings for the classes, ordered to line up with matrix rows/columns."""
    items = list(label_mappings.items())
    # With integer class-index keys, order by index so label k always annotates row/column k,
    # independent of the dict's insertion order. Other key types keep insertion order.
    if all(isinstance(key, (int, np.integer)) for key, _ in items):
        items.sort(key=lambda kv: kv[0])
    spaced = [_LABEL_BOUNDARY.sub(r'\1 ', str(value)) for _, value in items]
    # Wrap long names so the rotated x tick labels stay readable.
    return ['\n'.join(textwrap.wrap(label, width)) for label in spaced]


def _cell_text(value) -> str:
    """Return the in-cell annotation: '.' for zero, an integer for whole numbers, else 2 decimals."""
    if value == 0:
        return '.'
    if float(value).is_integer():
        return str(int(value))
    # Non-integral entries (e.g. a normalized matrix) would otherwise be truncated by int().
    return f'{value:.2f}'


def plot_confusion_matrix(cm: np.ndarray, label_mappings: typing.Dict, num_classes: int) -> plt.Figure:
    """
    | **@author:** Prathyush SP
    |
    | Create a confusion matrix using matplotlib

    Rows are the true labels and columns are the predictions. Zero cells are shown as '.'.
    Whole-number cells are shown as integers and any other value with two decimals.

    :param cm: A confusion matrix: a square array-like of shape ``(num_classes, num_classes)``
    :param label_mappings: The labels for classes. With integer keys, labels are ordered by key.
                           Otherwise they are used in insertion order.
    :param num_classes: Number of classes
    :return: A ``matplotlib.figure.Figure`` object with a numerical and graphical representation of the cm array
    :raises ValueError: if ``cm`` is not ``num_classes`` x ``num_classes``
    """
    cm = np.asarray(cm)
    if cm.shape != (num_classes, num_classes):
        raise ValueError(f'cm must have shape ({num_classes}, {num_classes}), got {cm.shape}')

    # A standalone Figure (not pyplot.figure) avoids registering it with pyplot's global state.
    fig = plt.Figure(figsize=(num_classes, num_classes), dpi=333, facecolor='w', edgecolor='k')
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(cm, cmap='YlGn')

    classes = _humanize_labels(label_mappings)
    tick_marks = np.arange(len(classes))

    # Setup Predicted
    ax.set_xlabel('Prediction')
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, rotation=-90, ha='center')
    ax.xaxis.set_label_position('bottom')
    ax.xaxis.tick_bottom()

    # Setup Label
    ax.set_ylabel('Label')
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes, va='center')
    ax.yaxis.set_label_position('left')
    ax.yaxis.tick_left()

    # imshow puts row i at y=i and column j at x=j, hence text(x=j, y=i).
    for i, j in itertools.product(range(num_classes), repeat=2):
        ax.text(j, i, _cell_text(cm[i, j]), horizontalalignment="center",
                verticalalignment='center', color="black")

    # set_tight_layout is deprecated in newer matplotlib. Prefer the layout-engine API when available.
    if hasattr(fig, 'set_layout_engine'):
        fig.set_layout_engine('tight')
    else:
        fig.set_tight_layout(True)
    return fig
def figure_to_summary(fig: plt.Figure, name: str) -> 'tf.compat.v1.Summary':
    """
    | **@author:** Prathyush SP
    |
    | Converts a matplotlib figure into a TensorFlow Summary object

    The figure is rendered to PNG through an Agg canvas. If the figure's canvas is not already
    Agg-based, an Agg canvas is attached to it, which replaces ``fig.canvas``.

    :param fig: A matplotlib.figure.Figure object.
    :param name: Name of the plot (used as the summary tag)
    :return: A TensorFlow Summary protobuf object containing the plot image
             as an RGBA image summary.
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg

    # Recent matplotlib attaches a generic FigureCanvasBase to every Figure, so a `canvas is None`
    # check never fires. That base canvas is not guaranteed to provide print_png. Attach Agg unless
    # the canvas already is one (interactive *Agg backends subclass FigureCanvasAgg and are kept).
    canvas = fig.canvas
    if not isinstance(canvas, FigureCanvasAgg):
        canvas = FigureCanvasAgg(fig)
    canvas.draw()
    width, height = canvas.get_width_height()

    # get PNG data from the figure. The context manager closes the buffer even if rendering fails.
    with io.BytesIO() as png_buffer:
        canvas.print_png(png_buffer)
        png_encoded = png_buffer.getvalue()

    # tf.Summary only exists in TF1. tf.compat.v1.Summary is the same proto on TF >= 1.13 and TF2.
    summary_cls = getattr(tf, 'Summary', None) or tf.compat.v1.Summary
    summary_image = summary_cls.Image(height=height, width=width, colorspace=4,  # RGB-A
                                      encoded_image_string=png_encoded)
    return summary_cls(value=[summary_cls.Value(tag=name, image=summary_image)])
if __name__ == "__main__":
    # Demo: a random 2-class confusion matrix of integer counts, rendered and wrapped in a summary.
    # Integer counts are used because a confusion matrix holds counts.
    metric_values = np.random.randint(0, 100, size=(2, 2))
    # Fixed: the original `label_mapping:{...}` was a SyntaxError. The parameter is `label_mappings`.
    figure = plot_confusion_matrix(metric_values,
                                   label_mappings={0: 'Event Occured', 1: "Event Skipped"},
                                   num_classes=2)
    summary = figure_to_summary(figure, name="confusion_matrix")
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.