Skip to content

Instantly share code, notes, and snippets.

@kingspp
Last active April 8, 2020 15:46
Show Gist options
  • Save kingspp/e87d35a8a9c70d699a2b3bb941dfb4bd to your computer and use it in GitHub Desktop.
Save kingspp/e87d35a8a9c70d699a2b3bb941dfb4bd to your computer and use it in GitHub Desktop.
Showcases rendering a custom metric (a confusion matrix) as an image summary in TensorFlow
import io
import itertools
import re
import struct
import textwrap
import typing

import matplotlib.backends.backend_agg
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
def plot_confusion_matrix(cm: np.ndarray, label_mappings: typing.Dict, num_classes: int) -> plt.Figure:
    """
    | **@author:** Prathyush SP
    |
    | Create a confusion matrix figure using matplotlib.

    :param cm: A square confusion matrix (anything ``np.asarray`` accepts) with at least
        ``num_classes`` rows and columns. Entry ``cm[i, j]`` is drawn at row ``i``
        (the "Label" axis) and column ``j`` (the "Prediction" axis).
    :param label_mappings: Mapping whose values are the class names used as tick labels,
        in the dict's iteration order.
    :param num_classes: Number of classes; sets the figure size (in inches) and the
        number of rows/columns that get a numeric annotation.
    :return: A ``matplotlib.figure.Figure`` object with a numerical and graphical
        representation of the cm array.
    """
    # Accept lists / nested sequences as well as arrays; cm[i, j] indexing below needs an array.
    cm = np.asarray(cm)

    fig = plt.Figure(figsize=(num_classes, num_classes), dpi=333, facecolor='w', edgecolor='k')
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(cm, cmap='YlGn')

    # NOTE(review): tick labels follow the dict's iteration order, not its keys; this assumes
    # label_mappings was built in class-index order (e.g. {0: ..., 1: ...}) — confirm with callers.
    names = [str(value) for value in label_mappings.values()]
    # Insert a space at camelCase / acronym boundaries ("EventSkipped" -> "Event Skipped").
    names = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', name) for name in names]
    # Wrap long names at 20 characters so tick labels stay compact.
    classes = ['\n'.join(textwrap.wrap(name, 20)) for name in names]
    tick_marks = np.arange(len(classes))

    # Setup predicted (x) axis
    ax.set_xlabel('Prediction')
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, rotation=-90, ha='center')
    ax.xaxis.set_label_position('bottom')
    ax.xaxis.tick_bottom()

    # Setup label (y) axis
    ax.set_ylabel('Label')
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes, va='center')
    ax.yaxis.set_label_position('left')
    ax.yaxis.tick_left()

    # Annotate each cell with int(value) (truncating any fractional part); exact zeros show '.'.
    for i, j in itertools.product(range(num_classes), range(num_classes)):
        ax.text(j, i, int(cm[i, j]) if cm[i, j] != 0 else '.', horizontalalignment="center",
                verticalalignment='center', color="black")

    # Figure.set_tight_layout is discouraged since matplotlib 3.6 in favour of layout engines;
    # keep the old call as a fallback for matplotlib versions without set_layout_engine.
    if hasattr(fig, 'set_layout_engine'):
        fig.set_layout_engine('tight')
    else:
        fig.set_tight_layout(True)
    return fig
def figure_to_summary(fig: plt.Figure, name: str) -> "tf.Summary":
    """
    | **@author:** Prathyush SP
    |
    | Converts a matplotlib figure into a TensorFlow Summary object

    :param fig: A matplotlib.figure.Figure object.
    :param name: Tag under which the image is stored in the summary.
    :return: A TensorFlow ``Summary`` protobuf object containing the figure rendered
        as a PNG image summary.
    """
    # Figure.savefig switches to a PNG-capable canvas by itself, so this works whether or not
    # the figure has a canvas attached. Recent matplotlib versions give every new Figure a
    # FigureCanvasBase (so a "canvas is None" check never fires) and that base canvas has no
    # print_png, so rendering through fig.canvas directly is not reliable.
    with io.BytesIO() as png_buffer:
        fig.savefig(png_buffer, format='png')
        png_encoded = png_buffer.getvalue()

    # Read the pixel size from the PNG IHDR chunk (8-byte signature, 4-byte chunk length,
    # 4-byte b"IHDR", then big-endian uint32 width and height) so it always matches the
    # encoded image, even if savefig's dpi differs from the figure's own dpi.
    width, height = struct.unpack('>II', png_encoded[16:24])

    # tf.Summary only exists in TF1; TF2 exposes the same protobuf as tf.compat.v1.Summary.
    summary_cls = getattr(tf, 'Summary', None) or tf.compat.v1.Summary
    summary_image = summary_cls.Image(height=height, width=width, colorspace=4,  # RGB-A
                                      encoded_image_string=png_encoded)
    return summary_cls(value=[summary_cls.Value(tag=name, image=summary_image)])
if __name__ == "__main__":
    # Demo: a random 2x2 matrix of integer counts. Integer counts are used because the cell
    # annotations truncate with int(), which would turn every value in [0, 1) into 0.
    metric_values = np.random.randint(0, 100, size=(2, 2))
    figure = plot_confusion_matrix(metric_values,
                                   label_mappings={0: 'Event Occurred', 1: "Event Skipped"},
                                   num_classes=2)
    summary = figure_to_summary(figure, name="confusion_matrix")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment