Created
August 22, 2018 15:04
-
-
Save takashioya/7f526d6e48b274c5ab78783b658167e6 to your computer and use it in GitHub Desktop.
aucの計算
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from itertools import combinations | |
from sklearn.metrics import roc_auc_score | |
def calc_auc_slow(label, pred):
    """Compute ROC AUC by brute-force comparison of every positive/negative pair.

    This is deliberately simple rather than fast: it performs
    ``n_pos * n_neg`` comparisons. It is meant as a transparent reference
    for checking faster implementations such as
    ``sklearn.metrics.roc_auc_score``.

    Each (positive, negative) pair scores 1 when the positive sample has the
    strictly higher prediction and 0.5 when the two predictions tie. The AUC
    is the total score divided by the number of such pairs.

    Args:
        label: 1-D array-like of labels. It must contain exactly two distinct
            values; the larger value is treated as the positive class.
        pred: 1-D array-like of scores, with the same shape as ``label``.

    Returns:
        The AUC as a float in [0, 1].

    Raises:
        ValueError: if ``label`` and ``pred`` have different shapes, or if
            ``label`` does not contain exactly two distinct classes (AUC is
            undefined otherwise).
    """
    label = np.asarray(label)
    pred = np.asarray(pred)
    if label.shape != pred.shape:
        raise ValueError(
            f"label and pred must have the same shape, got {label.shape} and {pred.shape}"
        )

    # np.unique returns sorted classes, so the larger label is the positive one.
    # This works for any two label values, e.g. {0, 1}, {-1, 1} or {0, 2}.
    classes, counts = np.unique(label, return_counts=True)
    if classes.size != 2:
        raise ValueError(
            f"calc_auc_slow needs exactly two label classes, got {classes.size}"
        )
    negative_class, positive_class = classes
    n_neg, n_pos = counts

    pos_scores = pred[label == positive_class]
    neg_scores = pred[label == negative_class]

    # Compare every positive score with every negative score via broadcasting.
    # Comparisons are used instead of subtraction so unsigned dtypes cannot wrap.
    wins = np.count_nonzero(pos_scores[:, None] > neg_scores[None, :])
    ties = np.count_nonzero(pos_scores[:, None] == neg_scores[None, :])
    return (wins + 0.5 * ties) / (n_pos * n_neg)
if __name__ == '__main__':
    # Self-check: the brute-force AUC must agree with scikit-learn's AUC on a
    # small example that includes tied predictions across the two classes.
    label = np.array([1, 0, 1, 0, 0, 1, 0, 0, 1])
    pred = np.array([1, 2, 3, 2, 3, 2, 2, 1, 1])
    expected = roc_auc_score(label, pred)
    actual = calc_auc_slow(label, pred)
    # The two implementations accumulate in different orders, so compare with
    # a tolerance instead of exact float equality. Raise explicitly rather
    # than using `assert`, which is stripped under `python -O`.
    if not np.isclose(expected, actual):
        raise AssertionError(f"AUC mismatch: sklearn={expected}, slow={actual}")
    print(expected)
    print(actual)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment