Last active
February 27, 2025 03:00
-
-
Save naufalso/5f065b9f64554c8c92df1399c6da438f to your computer and use it in GitHub Desktop.
Download evaluation data from Weave.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import weave | |
import pandas as pd | |
from tqdm import tqdm | |
def get_calls(project_id, op_name, parent_id=None):
    """Query a Weave project for the calls of one operation.

    Args:
        project_id: Project identifier; used both for ``weave.init`` and as
            the query's ``project_id``.
        op_name: Operation name to filter on (sent as ``filter.op_names``).
        parent_id: Optional parent call id. When truthy, only calls whose
            parent is this id are requested (``filter.parent_ids``).

    Returns:
        The result of ``client.server.calls_query_stream`` for the query,
        sorted by ``started_at`` descending (newest first).
    """
    call_filter = {"op_names": [op_name]}
    if parent_id:
        call_filter["parent_ids"] = [parent_id]

    request = {
        "project_id": project_id,
        "filter": call_filter,
        "sort_by": [{"field": "started_at", "direction": "desc"}],
        "expand_columns": ["inputs.model"],
        # None is passed through as-is; presumably "no row limit" -- confirm
        # against the Weave trace-server API.
        "limit": None,
    }
    return weave.init(project_id).server.calls_query_stream(request)
def extract_eval_item(call):
    """Flatten one evaluation call into a result row.

    The row starts as a copy of ``call.inputs["example"]`` (without its
    ``"_ref"`` entry) and gains three keys:

    * ``model_answer``    -- ``"answer"`` from the prediction dict,
    * ``original_answer`` -- ``"original_answer"`` from the prediction dict,
    * ``correct``         -- ``call.output["scores"]["match_answer"]["match"]``,
      looked up only when a model answer is present.

    The prediction dict is taken from ``call.output["model_output"]`` when that
    key exists, otherwise from ``call.output["output"]``. A ``None`` output or
    a ``None`` prediction (under either key) yields ``None`` for all three
    fields.

    Args:
        call: A Weave call object exposing ``inputs`` and ``output``.

    Returns:
        dict: The flattened row.

    Raises:
        Exception: Propagates lookup errors for unexpected output shapes
        (e.g. neither prediction key present, or no ``scores`` entry while an
        answer exists). If the prediction lookup fails, the raw output is
        printed before re-raising.
    """
    data = dict(call.inputs["example"])  # copy so the call's inputs stay untouched
    data.pop("_ref", None)

    output = call.output
    model_answer = original_answer = correct = None
    if output is not None:
        try:
            prediction = output["model_output"] if "model_output" in output else output["output"]
            # Treat a None prediction the same way regardless of which key held it.
            if prediction is not None:
                model_answer = prediction.get("answer")
                original_answer = prediction.get("original_answer")
        except Exception:
            # Unexpected output shape: dump it to aid debugging, then propagate.
            print(output)
            raise
        if model_answer is not None:
            correct = output["scores"]["match_answer"]["match"]

    data.update({"model_answer": model_answer, "original_answer": original_answer, "correct": correct})
    return data
def fetch_evaluation_data(project_id, op_name, parent_id=None):
    """Download evaluation calls from Weave and flatten them into a DataFrame.

    Queries every call of ``op_name`` in ``project_id``, sorted by
    ``started_at`` descending, with the ``inputs.example`` column expanded.
    When ``parent_id`` is truthy, only children of that call are requested.
    Each call is converted to a row by ``extract_eval_item``.

    Extraction is best-effort: a call whose extraction raises an ``Exception``
    is skipped. The progress bar reports both extracted and skipped counts.

    Args:
        project_id: Project identifier; used for ``weave.init`` and the query.
        op_name: Operation name to filter on.
        parent_id: Optional parent call id to restrict the query to.

    Returns:
        pandas.DataFrame: One row per successfully extracted call.
    """
    client = weave.init(project_id)
    call_filter = {"op_names": [op_name]}
    if parent_id:
        call_filter["parent_ids"] = [parent_id]
    query_data = {
        "project_id": project_id,
        "filter": call_filter,
        "sort_by": [{"field": "started_at", "direction": "desc"}],
        "expand_columns": ["inputs.example"],
        "limit": None,
    }
    calls = client.server.calls_query_stream(query_data)

    eval_data = []
    failures = 0
    with tqdm(calls, desc="Extracting data", unit="call") as t:
        for call in t:
            try:
                eval_data.append(extract_eval_item(call))
            except Exception:
                # Skip malformed calls, but (unlike a bare ``except:``) let
                # KeyboardInterrupt / SystemExit stop the download.
                failures += 1
            t.set_postfix(success=len(eval_data), failed=failures)
    return pd.DataFrame(eval_data)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment