@CVxTz
Created February 15, 2020 18:09
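import numpy as np

def kl_loss(y_true, y_pred):
    # KL(y_true || y_pred) = sum(p * log(p / q)), in nats.
    # Assumes every entry of both arrays is > 0: a zero in y_true gives nan, a zero in y_pred gives inf.
    return np.sum(y_true * np.log(y_true / y_pred))
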
y_true = np.array([0.1, 0.5, 0.2, 0.15, 0.05])
y_pr_1 = np.array([0.15, 0.4, 0.2, 0.2, 0.05])
y_pr_2 = np.array([0.5, 0.1, 0.1, 0.2, 0.1])
print(np.sum(y_true), np.sum(y_pr_2), np.sum(y_pr_1))  # sanity check: each distribution should sum to 1
print(kl_loss(y_true, y_pr_1))  # => ~0.028 because y_pr_1 is similar to y_true
print(kl_loss(y_true, y_pr_2))  # => ~0.70
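
The `kl_loss` above only works when every probability is strictly positive. Below is a minimal sketch of a zero-safe variant, assuming the standard convention 0 * log(0 / q) = 0 (terms where p is 0 are dropped) and letting the result be infinite when q is 0 where p is not. The name `kl_divergence` is hypothetical and not part of the original gist. The usage lines also show that KL divergence is not symmetric: swapping the arguments changes the value.

```python
import numpy as np


def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i), in nats (hypothetical helper).

    Terms with p_i == 0 contribute 0 (convention 0 * log(0 / q) = 0), so they
    are masked out instead of producing nan as the plain formula would.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0
    # If q_i == 0 where p_i > 0, p_i / q_i is inf and the divergence is inf;
    # silence numpy's divide-by-zero warning for that case.
    with np.errstate(divide="ignore"):
        return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))


y_true = np.array([0.1, 0.5, 0.2, 0.15, 0.05])
y_pr_1 = np.array([0.15, 0.4, 0.2, 0.2, 0.05])

print(round(kl_divergence(y_true, y_pr_1), 4))  # 0.0279
print(round(kl_divergence(y_pr_1, y_true), 4))  # 0.0291, not symmetric

# A zero in p no longer yields nan: KL = 0.5*log(2) + 0.5*log(2) = log(2)
print(round(kl_divergence([0.5, 0.5, 0.0], [0.25, 0.25, 0.5]), 4))  # 0.6931
```

In practice, `scipy.special.rel_entr` computes the same elementwise terms with these zero conventions, so `np.sum(rel_entr(p, q))` is an alternative if SciPy is available.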