Skip to content

Instantly share code, notes, and snippets.

@Burntt
Created March 1, 2022 07:26
Show Gist options
  • Save Burntt/540655ce4565fd5d7e7059b58a3d0656 to your computer and use it in GitHub Desktop.
Save Burntt/540655ce4565fd5d7e7059b58a3d0656 to your computer and use it in GitHub Desktop.
def perturbation_rank(model, x, y, names, loss_fn=None, device=None):
    """Rank input features by permutation (perturbation) importance.

    Each feature column of ``x`` is shuffled in turn while every other column
    keeps its original values. The model is re-evaluated on the perturbed
    matrix, and the loss against ``y`` is recorded. A larger loss after
    shuffling a column means the model relied more on that feature.
    Shuffling uses the global NumPy RNG (``np.random.shuffle``), so call
    ``np.random.seed`` beforehand for reproducible results.

    Args:
        model: ``torch.nn.Module`` to evaluate. It is put in eval mode for the
            computation, and its previous train/eval mode is restored
            afterwards.
        x: 2-D feature tensor, with one column per feature (columns are
            indexed with ``[:, i]``). It is never modified, nor is any numpy
            array it shares memory with.
        y: Targets. They are flattened to 1-D before being passed to
            ``loss_fn``.
        names: One name per column of ``x``, in column order.
        loss_fn: Callable ``loss_fn(y_true, y_pred) -> float``. Defaults to
            ``metrics.log_loss`` from the module-level ``metrics`` import
            (presumably ``sklearn.metrics``; it expects ``y_pred`` to hold
            probabilities).
        device: Device the perturbed batch is moved to before calling
            ``model``. Defaults to the device of the model's first parameter,
            or CPU for a parameter-less model.

    Returns:
        ``pandas.DataFrame`` with columns ``name``, ``error`` and
        ``importance`` (each error divided by the largest error). It is sorted
        by importance in descending order and has a fresh 0..n-1 index. The
        frame is empty when ``x`` has no columns.
    """
    if loss_fn is None:
        loss_fn = metrics.log_loss
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    columns = ['name', 'error', 'importance']

    # Work on a private host-side copy. Tensor.numpy() shares memory with the
    # tensor (and torch.from_numpy with the caller's array), so shuffling
    # without a copy would corrupt the caller's data and let earlier shuffles
    # leak into later features' scores.
    work = x.detach().cpu().numpy().copy()
    if torch.is_tensor(y):
        targets = y.detach().cpu().numpy().flatten()
    else:
        targets = np.asarray(y).flatten()

    errors = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for col in range(work.shape[1]):
                saved_col = work[:, col].copy()
                np.random.shuffle(work[:, col])
                batch = torch.from_numpy(work).float().to(device)
                preds = model(batch).detach().cpu().numpy()
                errors.append(loss_fn(targets, preds))
                # Restore the column so exactly one feature is perturbed per pass.
                work[:, col] = saved_col
    finally:
        # Do not leave the caller's model silently switched to eval mode.
        model.train(was_training)

    if not errors:
        return pd.DataFrame(columns=columns)

    max_error = np.max(errors)
    importance = [e / max_error for e in errors]
    result = pd.DataFrame(
        {'name': names, 'error': errors, 'importance': importance},
        columns=columns,
    )
    result.sort_values(by=['importance'], ascending=False, inplace=True)
    result.reset_index(inplace=True, drop=True)
    return result
# Feature names: every column of the source frame except the target label
# column. list.remove raises ValueError if 'label_barrier' is absent.
names = [*fractional_diff_data.columns]
names.remove('label_barrier')
# NOTE(review): if X_test is already float32, torch.from_numpy(...).float()
# does not copy, so the tensor passed below shares memory with X_test.
rank = perturbation_rank(
    model_NN1,
    torch.from_numpy(X_test).float(),
    torch.from_numpy(y_test.astype(int)).long(),
    names,
)
display(rank)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment