Created
April 18, 2025 19:20
-
-
Save pszemraj/972ed730614139a68a062c86e3721308 to your computer and use it in GitHub Desktop.
multiple‑choice dataset aggregator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
create_unified_mcqa.py – “batteries‑included” multiple‑choice aggregator | |
✅ Handles all datasets listed in the conversation | |
✅ Survives missing/renamed columns | |
✅ Converts every `label` to pure int64 to avoid ClassLabel clashes | |
✅ Explicitly casts features to ensure concatenation compatibility | |
✅ Improved error handling and skipping for malformed examples | |
✅ Limits warning/info messages per dataset | |
✅ Fixes column mismatch error during cast | |
✅ Improved math_qa choice parsing | |
""" | |
from collections import defaultdict | |
from typing import Dict, Any, Callable, List, Union, Optional | |
import re | |
import sys | |
import traceback # For detailed error logging | |
import datasets | |
from datasets import Dataset, concatenate_datasets, Features, Value, Sequence | |
from tqdm import tqdm | |
# ─────────────────────────────────────────────────────────────────────────────
# 0) CONFIGURATION & TARGET SCHEMA
# ─────────────────────────────────────────────────────────────────────────────
MAX_WARNINGS_PER_DATASET = 5  # Limit excessive logging for common issues

# Canonical schema every source dataset is cast to before concatenation.
# Keeping 'label' as plain int64 (not ClassLabel) avoids ClassLabel clashes
# between sources that use different label vocabularies.
TARGET_FEATURES = Features({
    'context': Value('string'),            # supporting passage/dialogue ("" when absent)
    'question': Value('string'),           # question or prompt text
    'choices': Sequence(Value('string')),  # answer options; canonicalisers guarantee >= 2
    'label': Value('int64'),               # 0-based index of the correct choice
    'source_dataset': Value('string'),     # Added source dataset identifier
})

# Global warning counter: per-dataset message count consulted by
# log_warning/log_info to cap console noise at MAX_WARNINGS_PER_DATASET.
warning_counters = defaultdict(int)
def log_warning(dataset_name: str, message: str):
    """Print a warning for *dataset_name*, capped at MAX_WARNINGS_PER_DATASET."""
    count = warning_counters[dataset_name]
    if count < MAX_WARNINGS_PER_DATASET:
        print(f"Warning [{dataset_name}]: {message}", file=sys.stderr)
        warning_counters[dataset_name] = count + 1
    elif count == MAX_WARNINGS_PER_DATASET:
        # Emit one final notice, then bump past the threshold so this
        # suppression message itself is never repeated.
        print(f"Warning [{dataset_name}]: (Further warnings suppressed for this dataset)", file=sys.stderr)
        warning_counters[dataset_name] = count + 1
def log_info(dataset_name: str, message: str):
    """Print an info message (e.g. example skips), subject to the shared per-dataset cap."""
    # Reads the same counter as log_warning but never increments it, so info
    # lines stop once the warning quota is spent without consuming quota.
    # No suppression notice is printed for info, to avoid clutter.
    within_quota = warning_counters[dataset_name] < MAX_WARNINGS_PER_DATASET
    if within_quota:
        print(f"Info [{dataset_name}]: {message}", file=sys.stderr)
# ─────────────────────────────────────────────────────────────────────────────
# 1) REGISTRY
#    Hugging Face Hub ids (`path`, plus optional config `name`) for every
#    source dataset to aggregate.  Keys are the short names used throughout
#    (SPECIAL, COLMAP, logging, and the 'source_dataset' column).
# ─────────────────────────────────────────────────────────────────────────────
REGISTRY: Dict[str, Dict[str, Any]] = {
    "race": dict(path="race", name="all"),
    "commonsense_qa": dict(path="tau/commonsense_qa"),
    "sciq": dict(path="allenai/sciq"),
    "math_qa": dict(path="allenai/math_qa"),
    "swag": dict(path="allenai/swag", name="regular"),
    "hellaswag": dict(path="rowan/hellaswag"),
    "social_i_qa": dict(path="allenai/social_i_qa"),
    "cosmos_qa": dict(path="allenai/cosmos_qa"),
    "piqa": dict(path="ybisk/piqa"),
    "winogrande": dict(path="allenai/winogrande", name="winogrande_l"),
    "dream": dict(path="dataset-org/dream"),
    "quail": dict(path="textmachinelab/quail"),
    "medmcqa": dict(path="openlifescienceai/medmcqa"),
}
# ───────────────────────────────────────────────────────────────────────────── | |
# 2) SPECIAL CANONICALISERS | |
# Return None if an example is invalid/malformed and should be skipped. | |
# Now include 'source_dataset' in the returned dict. | |
# ───────────────────────────────────────────────────────────────────────────── | |
def _ensure_string_choices(choices: List[Any]) -> List[str]: | |
return [str(c) if c is not None else "" for c in choices] | |
# --- Canonicalizers updated to accept dataset_name and return source_dataset --- | |
def canon_commonsenseqa(ex: Dict[str, Any], dataset_name: str) -> Optional[Dict[str, Any]]:
    """Canonicalise one CommonsenseQA example.

    Choices are reordered to follow the sorted label order ('A', 'B', ...),
    and the returned label is the 0-based index of the answer *within that
    sorted order*.  Returns None for malformed examples so callers can
    filter them out.
    """
    try:
        # Basic structure checks
        choices_data = ex.get("choices")
        if not choices_data or not isinstance(choices_data, dict):
            log_warning(dataset_name, f"Missing or invalid 'choices' dict: {ex.get('id', 'N/A')}")
            return None
        labels = choices_data.get("label")
        texts = choices_data.get("text")
        answer_key = ex.get("answerKey")
        if not labels or not texts or answer_key is None:
            log_warning(dataset_name, f"Missing labels, texts, or answerKey in choices: {ex.get('id', 'N/A')}")
            return None
        if len(labels) != len(texts):
            log_warning(dataset_name, f"Mismatch length label/text: {ex.get('id', 'N/A')}")
            return None
        # Reorder choice texts so they follow sorted label order.
        order = sorted(range(len(labels)), key=lambda i: labels[i])
        choices_text = [texts[i] for i in order]
        # BUGFIX: the label must index into the *reordered* choices, not the
        # raw labels list — the two differ whenever labels arrive unsorted.
        label = order.index(labels.index(answer_key))
    except (KeyError, ValueError, IndexError, TypeError) as e:
        log_warning(dataset_name, f"Failed to process example {ex.get('id', 'N/A')}. Error: {e}")
        return None
    return {
        "context": "",
        "question": str(ex.get("question", "")),
        # Inlined string coercion (same semantics as _ensure_string_choices)
        # so the function is self-contained.
        "choices": [str(c) if c is not None else "" for c in choices_text],
        "label": int(label),
        "source_dataset": dataset_name,
    }
def canon_sciq(ex: Dict[str, Any], dataset_name: str) -> Optional[Dict[str, Any]]:
    """Canonicalise one SciQ example; the correct answer is always placed first."""
    try:
        wrong_answers = []
        if "distractors" in ex and isinstance(ex["distractors"], list):
            # Preferred layout: a ready-made list of distractors.
            wrong_answers = ex["distractors"]
        elif "distractor1" in ex and "distractor2" in ex and "distractor3" in ex:
            # Flattened layout: three individual distractor columns.
            d1, d2, d3 = ex.get("distractor1"), ex.get("distractor2"), ex.get("distractor3")
            if d1 is None or d2 is None or d3 is None:
                log_warning(dataset_name, f"Missing one or more distractorN keys: {ex.get('id', 'N/A')}")
                return None
            wrong_answers = [d1, d2, d3]
        else:
            log_warning(dataset_name, f"Could not find distractors list or keys: {ex.get('id', 'N/A')}")
            return None
        right_answer = ex.get("correct_answer")
        if right_answer is None:
            log_warning(dataset_name, f"Missing 'correct_answer': {ex.get('id', 'N/A')}")
            return None
        all_choices = [right_answer] + list(wrong_answers)
        if len(all_choices) < 2:  # need at least two options
            log_warning(dataset_name, f"Too few choices ({len(all_choices)}) after processing: {ex.get('id', 'N/A')}")
            return None
    except (KeyError, TypeError) as e:
        log_warning(dataset_name, f"Failed to process example {ex.get('id', 'N/A')}. Error: {e}")
        return None
    return {
        "context": str(ex.get("support", "")),
        "question": str(ex.get("question", "")),
        "choices": ["" if c is None else str(c) for c in all_choices],
        "label": 0,  # Correct answer is always first by construction here
        "source_dataset": dataset_name,
    }
def canon_hellaswag(ex: Dict[str, Any], dataset_name: str) -> Optional[Dict[str, Any]]:
    """Canonicalise one HellaSwag example (context-completion task)."""
    try:
        first_part = str(ex.get('ctx_a', '')).strip()
        second_part = str(ex.get('ctx_b', '')).strip()
        context = f"{first_part} {second_part}".strip()
        endings = ex.get("endings")  # preferred: a ready-made list
        if not isinstance(endings, list):
            # Fall back to the flattened ending0..ending3 columns.
            endings = []
            for idx in range(4):
                key = f"ending{idx}"
                candidate = ex.get(key)
                if candidate is None:
                    log_info(dataset_name, f"Skipping example {ex.get('ind', 'N/A')} due to missing key '{key}'.")
                    return None
                endings.append(candidate)
        if len(endings) < 2:
            log_warning(dataset_name, f"No valid choices found for example {ex.get('ind', 'N/A')}")
            return None
        label = int(ex["label"])  # already a 0-based index
        if label < 0 or label >= len(endings):
            log_warning(dataset_name, f"Invalid label {label} for {len(endings)} choices: {ex.get('ind', 'N/A')}")
            return None
    except (KeyError, ValueError, TypeError) as e:
        log_warning(dataset_name, f"Failed to process example {ex.get('ind', 'N/A')}. Error: {e}")
        return None
    return {
        "context": context,
        "question": str(ex.get("activity_label", "")),
        "choices": ["" if c is None else str(c) for c in endings],
        "label": label,
        "source_dataset": dataset_name,
    }
def canon_piqa(ex: Dict[str, Any], dataset_name: str) -> Optional[Dict[str, Any]]:
    """Canonicalise one PIQA example (binary choice between two solutions)."""
    try:
        solutions = [ex.get("sol1"), ex.get("sol2")]
        if solutions[0] is None or solutions[1] is None:
            log_warning(dataset_name, f"Missing sol1/sol2: {ex.get('id', 'N/A')}")
            return None
        label = int(ex["label"])  # must be 0 or 1
        if label not in (0, 1):
            log_warning(dataset_name, f"Invalid label {label}: {ex.get('id', 'N/A')}")
            return None
    except (KeyError, ValueError, TypeError) as e:
        log_warning(dataset_name, f"Failed to process example {ex.get('id', 'N/A')}. Error: {e}")
        return None
    return {
        "context": "",
        "question": str(ex.get("goal", "")),
        "choices": ["" if s is None else str(s) for s in solutions],
        "label": label,
        "source_dataset": dataset_name,
    }
def canon_dream(ex: Dict[str, Any], dataset_name: str) -> Optional[Dict[str, Any]]:
    """Canonicalise one DREAM example (dialogue comprehension)."""
    try:
        context = "\n".join(ex.get("dialogue", []))
        question = str(ex.get("question", ""))
        raw_choices = ex.get("choice")  # DREAM stores options under 'choice', not option_X
        if not isinstance(raw_choices, list) or len(raw_choices) < 2:
            log_warning(dataset_name, f"Missing or invalid 'choice' list: {ex.get('id', 'N/A')}")
            return None
        answer_text = ex.get("answer")  # text of the correct option
        if answer_text is None:
            log_warning(dataset_name, f"Missing 'answer' field: {ex.get('id', 'N/A')}")
            return None
        # Locate the answer by exact string match against the stringified choices.
        str_choices = ["" if c is None else str(c) for c in raw_choices]
        try:
            label = str_choices.index(str(answer_text))
        except ValueError:
            log_warning(dataset_name, f"Answer text '{answer_text}' not found in choices {str_choices}: {ex.get('id', 'N/A')}")
            return None
        if not (0 <= label < len(raw_choices)):
            # Safeguard only; cannot trigger when index() succeeded.
            log_warning(dataset_name, f"Invalid label {label} for {len(raw_choices)} choices: {ex.get('id', 'N/A')}")
            return None
    except (KeyError, ValueError, TypeError, IndexError) as e:
        log_warning(dataset_name, f"Failed to process example {ex.get('id', 'N/A')}. Error: {e}")
        return None
    return {
        "context": context,
        "question": question,
        "choices": str_choices,
        "label": int(label),
        "source_dataset": dataset_name,
    }
def canon_quail(ex: Dict[str, Any], dataset_name: str) -> Optional[Dict[str, Any]]:
    """Canonicalise one QuAIL example (reading comprehension, usually 4 options)."""
    try:
        answer_options = ex.get("answers")
        if not isinstance(answer_options, list) or len(answer_options) < 2:
            log_warning(dataset_name, f"Missing or invalid 'answers' list: {ex.get('id', 'N/A')}")
            return None
        label = int(ex["correct_answer_id"])  # already a 0-based index
        if label < 0 or label >= len(answer_options):
            log_warning(dataset_name, f"Invalid label {label} for {len(answer_options)} choices: {ex.get('id', 'N/A')}")
            return None
    except (KeyError, ValueError, TypeError) as e:
        log_warning(dataset_name, f"Failed to process example {ex.get('id', 'N/A')}. Error: {e}")
        return None
    return {
        "context": str(ex.get("context", "")),
        "question": str(ex.get("question", "")),
        "choices": ["" if a is None else str(a) for a in answer_options],
        "label": label,
        "source_dataset": dataset_name,
    }
def canon_medmcqa(ex: Dict[str, Any], dataset_name: str) -> Optional[Dict[str, Any]]:
    """Canonicalise one MedMCQA example.

    The correct-option field 'cop' appears in several formats across dumps:
    a letter 'A'-'D', a 0-based int, or a 1-based numeric string.  All are
    normalised to a 0-based int label.  Returns None for malformed examples.
    """
    try:
        choices = []
        for k in ("opa", "opb", "opc", "opd"):
            choice = ex.get(k)
            if choice is None:
                log_info(dataset_name, f"Skipping example {ex.get('id', 'N/A')} due to missing key '{k}'.")
                return None
            choices.append(choice)
        if len(choices) < 2:
            log_warning(dataset_name, f"Too few choices found: {ex.get('id', 'N/A')}")
            return None
        cop_val = ex.get("cop")  # Correct option
        label = -1
        # BUGFIX: require a single character — `"" in "ABCD"` and
        # `"AB" in "ABCD"` are both True, which previously sent empty or
        # multi-letter strings into ord() and raised TypeError (misreported
        # as a generic processing error instead of an invalid-cop warning).
        if isinstance(cop_val, str) and len(cop_val) == 1 and cop_val.upper() in "ABCD":
            label = ord(cop_val.upper()) - ord("A")
        elif isinstance(cop_val, int) and 0 <= cop_val < len(choices):  # Handles 0-based index
            label = cop_val
        elif isinstance(cop_val, str) and cop_val.isdigit():  # Handles "1", "2", etc. (assume 1-based)
            val_int = int(cop_val)
            if 1 <= val_int <= len(choices):
                label = val_int - 1
            else:
                log_warning(dataset_name, f"Numeric string 'cop' value {cop_val} out of range [1,{len(choices)}]: {ex.get('id', 'N/A')}")
                return None
        else:
            log_warning(dataset_name, f"Unexpected 'cop' value '{cop_val}' (type: {type(cop_val)}): {ex.get('id', 'N/A')}")
            return None
        if not (0 <= label < len(choices)):
            log_warning(dataset_name, f"Final label index {label} out of bounds [0,{len(choices)-1}]: {ex.get('id', 'N/A')}")
            return None
    except (KeyError, ValueError, TypeError) as e:
        log_warning(dataset_name, f"Failed to process example {ex.get('id', 'N/A')}. Error: {e}")
        return None
    return {
        "context": str(ex.get("exp", "")),  # Use 'exp' (explanation) as context
        "question": str(ex.get("question", "")),
        # Inlined string coercion (same semantics as _ensure_string_choices).
        "choices": [str(c) if c is not None else "" for c in choices],
        "label": int(label),
        "source_dataset": dataset_name,
    }
# Dispatch table: dataset short name → bespoke canonicaliser taking
# (example, dataset_name).  Datasets absent here fall through to the
# generic COLMAP-driven path (canon_from_map) in the main loop.
SPECIAL: Dict[str, Callable[[Dict[str, Any], str], Optional[Dict[str, Any]]]] = {
    "commonsense_qa": canon_commonsenseqa,
    "sciq": canon_sciq,
    "hellaswag": canon_hellaswag,
    "piqa": canon_piqa,
    "dream": canon_dream,
    "quail": canon_quail,
    "medmcqa": canon_medmcqa,
}
# ───────────────────────────────────────────────────────────────────────────── | |
# 3) GENERIC CANONICALISER | |
# Return None if an example is invalid/malformed and should be skipped. | |
# Now include 'source_dataset' in the returned dict. | |
# ───────────────────────────────────────────────────────────────────────────── | |
# Column maps consumed by canon_from_map.  Each entry describes where the
# target fields live in the source schema:
#   'context'/'question' — source column names (case-insensitive lookup)
#   'choices'            — one column holding a list or a delimited string
#   'choice_keys'        — explicit per-choice column names
#   'label'              — gold-answer column (letters / 0-based / 1-based)
COLMAP: Dict[str, Dict[str, Union[str, List[str]]]] = {
    "race": {
        "context": "article",
        "question": "question",
        "choices": "options",  # List of strings
        "label": "answer",  # 'A', 'B', 'C', 'D'
    },
    "math_qa": {
        "context": "Problem",
        "question": "Rationale",
        "choices": "options",  # String like "a) text , b) text ..."
        "label": "correct",  # 'a', 'b', 'c', 'd', 'e'
    },
    "swag": {
        "context": "sent1",
        "question": "startphrase",
        "choice_keys": ["ending0", "ending1", "ending2", "ending3"],
        "label": "label",  # 0, 1, 2, 3
    },
    "social_i_qa": {
        "context": "context",
        "question": "question",
        "choice_keys": ["answerA", "answerB", "answerC"],
        "label": "label",  # 1, 2, 3 -> needs conversion to 0-based
    },
    "cosmos_qa": {
        "context": "context",
        "question": "question",
        "choice_keys": ["answer0", "answer1", "answer2", "answer3"],
        "label": "label",  # 0, 1, 2, 3
    },
    "winogrande": {
        "context": "",  # no context column; question carries the sentence
        "question": "sentence",
        "choice_keys": ["option1", "option2"],
        "label": "answer",  # '1' or '2' -> needs conversion to 0-based
    },
}
def _ci_get(obj: dict, key: str, default: Any = None) -> Any: | |
"""Case‑insensitive safe lookup.""" | |
if not key: return default | |
# Try case-insensitive first | |
for k in obj: | |
if k.lower() == key.lower(): | |
return obj[k] | |
# Fallback to exact key match if case-insensitive fails | |
if key in obj: | |
return obj[key] | |
return default | |
def _parse_mathqa_options(options_str: str, dataset_name: str) -> Optional[List[str]]: | |
"""Specific parser for MathQA 'options' string.""" | |
if not isinstance(options_str, str): | |
log_warning(dataset_name, f"Expected string for MathQA options, got {type(options_str)}") | |
return None | |
# Split primarily by comma, as it seems the main separator between choices | |
# Need to be careful about commas *within* choices (e.g., "e) none of these") | |
# Let's try splitting by the pattern "letter )" or "letter ." | |
pattern = r'\s*([a-eA-E]\s*[\)\.])\s*' | |
# Use re.split but keep the delimiter to know where choices start | |
parts = re.split(f'({pattern})', options_str) | |
choices = [] | |
current_choice = "" | |
# Iterate through parts: ['', 'a ) ', 'choice text , ', 'b ) ', 'choice text ...'] | |
for i, part in enumerate(parts): | |
part = part.strip() | |
if not part: continue | |
is_delimiter = re.fullmatch(pattern, part) | |
if is_delimiter and current_choice: | |
# We hit a new delimiter, so the previous current_choice is complete | |
choices.append(current_choice.strip().rstrip(',')) # Remove trailing comma if any | |
current_choice = "" # Reset for next choice | |
elif not is_delimiter: | |
# This is part of the choice text | |
current_choice += part + " " # Add space in case split removed it | |
# Add the last choice | |
if current_choice: | |
choices.append(current_choice.strip().rstrip(',')) | |
# Basic validation | |
if not choices or len(choices) < 2: | |
# Fallback: try simple comma split if regex fails badly | |
choices = [c.strip() for c in options_str.split(',') if c.strip()] | |
# Clean prefixes again after comma split | |
choices = [re.sub(r"^\s*([a-eA-E][\)\.]|[0-9]+\.?)\s*", "", c).strip() for c in choices] | |
if not choices or len(choices) < 2: | |
log_warning(dataset_name, f"Could not parse MathQA options string into multiple choices: '{options_str}'") | |
return None # Return None if parsing fails | |
return choices | |
def canon_from_map(ex: Dict[str, Any], cmap: Dict[str, Union[str, List[str]]], dataset_name: str) -> Optional[Dict[str, Any]]:
    """Canonicalises an example based on a column map. Returns None on failure.

    The column map (see COLMAP) describes where each target field lives:
      - 'context' / 'question': source column names (case-insensitive lookup)
      - 'choice_keys': explicit list of per-choice column names, OR
      - 'choices': a single column holding a list or a delimited string
      - 'label': the gold-answer column; letters ('A'..), 0-based numbers,
        and the known 1-based schemes (social_i_qa, winogrande) are all
        normalised to a 0-based int index.
    """
    try:
        out = {
            "context": str(_ci_get(ex, cmap.get("context", ""), "")),
            "question": str(_ci_get(ex, cmap.get("question", ""), "")),
        }
        # --- Choices ---
        choices = []
        if "choice_keys" in cmap:  # Explicit list of keys for choices
            keys = cmap["choice_keys"]
            if not isinstance(keys, list):
                log_warning(dataset_name, f"'choice_keys' must be a list in COLMAP. Skipping.")
                return None
            for key in keys:
                choice = _ci_get(ex, key)
                if choice is None:
                    log_info(dataset_name, f"Skipping example due to missing choice key '{key}'.")
                    return None
                choices.append(choice)
        elif "choices" in cmap:  # Single key containing choices
            choices_key = cmap["choices"]
            raw_choices = _ci_get(ex, choices_key)
            if raw_choices is None:
                log_info(dataset_name, f"Skipping example due to missing choices key '{choices_key}'.")
                return None
            if isinstance(raw_choices, list):  # Already a list (like race)
                choices = raw_choices
            elif isinstance(raw_choices, str):  # String needing split
                if dataset_name == "math_qa":
                    parsed = _parse_mathqa_options(raw_choices, dataset_name)
                    if parsed is None: return None  # Skip if parsing failed
                    choices = parsed
                else:  # Generic string splitting (e.g., using '|')
                    # Accept '|', ';' or newline as separators between options.
                    split_delimiters = r'\s*\|\s*|\s*;\s*|\s*\n\s*'
                    parts = re.split(split_delimiters, raw_choices.strip())
                    # Basic prefix cleaning for generic case ("a)", "1.", ...)
                    choices = [re.sub(r"^\s*([a-zA-Z0-9][\)\.]|[0-9]+\.?)\s*", "", p).strip() for p in parts if p.strip()]
            else:
                log_warning(dataset_name, f"Unexpected type for choices field '{choices_key}': {type(raw_choices)}. Skipping.")
                return None
        else:
            log_warning(dataset_name, f"Column map must define 'choice_keys' or 'choices'. Skipping.")
            return None
        if not choices or len(choices) < 2:
            log_warning(dataset_name, f"No valid choices extracted (need >= 2). Skipping.")
            return None
        out["choices"] = _ensure_string_choices(choices)  # Ensure strings
        # --- Label ---
        label_key = cmap.get("label")
        if not label_key:
            log_warning(dataset_name, f"Column map missing 'label' key. Skipping.")
            return None
        label_val = _ci_get(ex, label_key)
        if label_val is None:
            log_info(dataset_name, f"Skipping example due to missing label key '{label_key}'.")
            return None
        label_int = -1  # Default invalid label
        # Datasets known to ship 1-based numeric labels.
        is_one_based_known = dataset_name in ["social_i_qa", "winogrande"]
        if isinstance(label_val, str):
            label_str = label_val.strip()
            if label_str.isdigit():
                val_int = int(label_str)
                if is_one_based_known and val_int >= 1:
                    label_int = val_int - 1
                else:  # Assume 0-based if numeric string otherwise
                    label_int = val_int
            elif len(label_str) == 1 and 'A' <= label_str.upper() <= 'Z':
                # Handle 'a'/'A' = 0, 'b'/'B' = 1 etc.
                label_int = ord(label_str.upper()) - ord('A')
            else:
                log_warning(dataset_name, f"Unhandled string label format: '{label_val}'. Skipping.")
                return None
        elif isinstance(label_val, (int, float)):  # Already numeric
            val_int = int(label_val)
            if is_one_based_known and val_int >= 1:
                label_int = val_int - 1
            else:  # Assume 0-based if integer otherwise
                label_int = val_int
        else:
            log_warning(dataset_name, f"Unexpected label type: {type(label_val)} for value '{label_val}'. Skipping.")
            return None
        # Validate label index is within bounds
        if not (0 <= label_int < len(out["choices"])):
            log_warning(dataset_name, f"Label index {label_int} out of bounds for choices length {len(out['choices'])}. Original label: '{label_val}'. Skipping.")
            return None
        out["label"] = label_int
        out["source_dataset"] = dataset_name  # Add source dataset here
        return out
    except Exception as e:
        # Broad catch by design: a single bad row is logged and dropped
        # rather than aborting the entire dataset map.
        log_warning(dataset_name, f"Unhandled exception processing example. Error: {e}. Data: {ex}")
        # traceback.print_exc(file=sys.stderr)  # Optionally print full trace for debugging
        return None  # Skip on any unexpected error
# ─────────────────────────────────────────────────────────────────────────────
# 4) MAIN PROCESSING LOOP
#    For each registered dataset: load the train split, canonicalise every
#    example, drop malformed rows, cast to TARGET_FEATURES, and bucket the
#    result by its number of answer choices.
# ─────────────────────────────────────────────────────────────────────────────
aggregated: Dict[int, List[Dataset]] = defaultdict(list)  # n_choices -> processed datasets
skipped_datasets: Dict[str, str] = {}  # dataset name -> reason it was skipped entirely
processed_counts: Dict[str, int] = defaultdict(int)  # dataset name -> rows kept
skipped_examples: Dict[str, int] = defaultdict(int)  # dataset name -> rows dropped
print("\n▶ loading + canonicalising\n")
warning_counters.clear()  # Reset counters for the run
for name, load_args in tqdm(REGISTRY.items(), desc="Datasets"):
    print(f"\nProcessing: {name}")
    warning_counters[name] = 0  # Reset counter for this dataset
    original_count = 0
    processed_ds = None
    # --- Load ---
    try:
        ds = datasets.load_dataset(**load_args, split="train", trust_remote_code=True)
        original_count = len(ds)
        original_columns = ds.column_names  # Get original columns *before* mapping
        print(f" Loaded {name} ({original_count:,} examples)")
    except Exception as e:
        print(f" Skipping {name} - load_error: {e}", file=sys.stderr)
        skipped_datasets[name] = f"load_error: {e}"
        continue
    try:
        # --- Pick a canonicaliser: bespoke (SPECIAL) takes priority over COLMAP ---
        map_function = None
        if name in SPECIAL:
            print(f" Applying special canonicaliser for {name}...")
            # `name` is bound per-iteration and map() runs immediately below,
            # so the late-binding-closure pitfall does not apply here.
            map_function = lambda ex: SPECIAL[name](ex, name)
        elif name in COLMAP:
            print(f" Applying generic canonicaliser for {name}...")
            map_function = lambda ex: canon_from_map(ex, COLMAP[name], name)
        else:
            print(f" Skipping {name} - No canonicaliser found (add to SPECIAL or COLMAP)", file=sys.stderr)
            skipped_datasets[name] = "no_canonicaliser"
            continue
        # --- Perform mapping ---
        # remove_columns ensures only the dict returned by map_function remains
        processed_ds = ds.map(
            map_function,
            remove_columns=original_columns,  # Remove all original columns
            desc=f"canon {name}",
            load_from_cache_file=False,  # Force re-processing
            batched=False  # Process example by example for skipping
        )
        # --- Filter out None examples (skipped by canonicalizers) ---
        # NOTE(review): datasets.Dataset.map may itself raise when the map
        # function returns None for a row — confirm skipped rows actually
        # survive to this filter instead of aborting the map call.
        initial_processed_count = len(processed_ds)
        processed_ds = processed_ds.filter(lambda ex: ex is not None)
        final_processed_count = len(processed_ds)
        num_skipped = initial_processed_count - final_processed_count
        if num_skipped > 0:
            print(f" Filtered out {num_skipped:,} invalid/skipped examples for {name}.")
            skipped_examples[name] = num_skipped
        # Check for empty datasets after filtering
        if final_processed_count == 0:
            print(f" Skipping {name} - resulted in empty dataset after processing and filtering.", file=sys.stderr)
            skipped_datasets[name] = "empty_after_processing"
            continue
        # --- Cast features explicitly ---
        # Now the dataset should *only* have the columns returned by the canonicalizer
        # which should match TARGET_FEATURES
        print(f" Casting features for {name}...")
        processed_ds = processed_ds.cast(TARGET_FEATURES)
        print(f" Cast successful for {name}.")
        # Determine number of choices
        # Use features if possible, fallback to first example
        n_choices = -1
        if isinstance(processed_ds.features['choices'], Sequence):
            # If length is fixed and known in features, use that
            # seq_len = processed_ds.features['choices'].length
            # if seq_len != -1: n_choices = seq_len
            # else: n_choices = len(processed_ds[0]["choices"]) # Fallback if variable length
            # NOTE(review): this assumes every row has the same choice count
            # as row 0 — verify for variable-choice-count sources.
            n_choices = len(processed_ds[0]["choices"])  # Simpler: just check first example after cast
        else:
            # This shouldn't happen after successful cast
            print(f" Skipping {name} - 'choices' feature is not a Sequence after casting.", file=sys.stderr)
            skipped_datasets[name] = "choices_not_sequence"
            continue
        if n_choices <= 1:
            print(f" Skipping {name} - only {n_choices} choice(s) found after processing. Requires >= 2.", file=sys.stderr)
            skipped_datasets[name] = f"too_few_choices ({n_choices})"
            continue
        aggregated[n_choices].append(processed_ds)
        processed_counts[name] = final_processed_count
        print(f"✓ {name:<14} → {n_choices}-choice, {final_processed_count:,} rows added (out of {original_count:,} original).")
    except Exception as e:
        print(f" Skipping {name} - unhandled map/filter/cast error: {e}", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)  # Print full traceback
        skipped_datasets[name] = f"map_filter_cast_error: {e}"
        continue
# ─────────────────────────────────────────────────────────────────────────────
# 5) CONCATENATE + REPORT
#    Merge each n-choice bucket into a single dataset, then print a run
#    summary and one demo sample per bucket.
# ─────────────────────────────────────────────────────────────────────────────
print("\n▶ concatenating datasets by choice count\n")
mega: Dict[int, Dataset] = {}  # n_choices -> single concatenated dataset
# Reset counters before final report
warning_counters.clear()
for k, v_list in aggregated.items():
    if v_list:
        print(f" Concatenating {len(v_list)} dataset(s) for {k}-choice group...")
        try:
            # Ensure features are identical before concatenating (should be due to cast)
            # first_features = v_list[0].features
            # for ds in v_list[1:]:
            #     if ds.features != first_features:
            #         print(f"ERROR: Feature mismatch in {k}-choice group before concat!", file=sys.stderr)
            #         print(f" Expected: {first_features}", file=sys.stderr)
            #         print(f" Found: {ds.features} in dataset from {ds[0]['source_dataset']}", file=sys.stderr)
            #         raise ValueError("Feature mismatch detected before concatenation")
            mega[k] = concatenate_datasets(v_list)
            print(f" Concatenated {k}-choice group ({len(mega[k]):,} rows) successfully.")
        except Exception as e:
            print(f" Failed to concatenate {k}-choice group: {e}", file=sys.stderr)
            skipped_datasets[f"concat_{k}_choice"] = f"concat_error: {e}"
            traceback.print_exc(file=sys.stderr)
    else:
        print(f" Skipping empty group for {k}-choice.")
print("\n★ SUMMARY ★")
total_rows = 0
print("\n--- Aggregated Datasets ---")
for k in sorted(mega.keys()):
    ds = mega[k]
    count = len(ds)
    # Get unique sources contributing to this group
    sources = set()
    if count > 0:
        try:
            # Sample sources - might be slow for large datasets
            sample_size = min(1000, count)
            sources = set(ds.select(range(sample_size))['source_dataset'])
            # If needed, get all sources (can be slow):
            # sources = set(ds['source_dataset'])
        except Exception as e:
            print(f" (Could not retrieve sources for {k}-choice group: {e})")
    num_datasets = len(sources) if sources else len(aggregated.get(k,[]))  # Approx if sampling fails
    print(f" {k}-choice : {count:>10,} rows (from ~{num_datasets} dataset(s))")
    total_rows += count
print(f" ──────────────────────────")
print(f" Total : {total_rows:>10,} rows")
print("\n--- Processed Examples per Source ---")
for name in sorted(REGISTRY.keys()):
    if name in processed_counts:
        # NOTE(review): `original` is never used below, and REGISTRY entries
        # never store '_original_count' — this lookup always yields 0 and is
        # effectively dead code.
        original = REGISTRY.get(name, {}).get('_original_count', 0)  # Get original count if stored
        skipped = skipped_examples.get(name, 0)
        added = processed_counts[name]
        print(f" {name:<20}: {added:>10,} rows added ({skipped:,} skipped)")
    elif name not in skipped_datasets:
        print(f" {name:<20}: {'0 rows (unexpected)':>10}")
if skipped_examples:
    print("\n--- Skipped Examples Summary ---")
    total_skipped = sum(skipped_examples.values())
    print(f" Total examples skipped across all datasets: {total_skipped:,}")
if skipped_datasets:
    print("\n⚠ SKIPPED DATASETS / FATAL ERRORS ⚠")
    for n, r in sorted(skipped_datasets.items()):
        # Keep the report readable: first line of the reason, max ~100 chars.
        error_msg = r.split('\n')[0]
        if len(error_msg) > 100: error_msg = error_msg[:97] + "..."
        print(f" {n:<25}: {error_msg}")
# Optional: Save the datasets
# SAVE_PATH = "./aggregated_mcqa_v7"
# print(f"\n💾 Saving aggregated datasets to {SAVE_PATH}...")
# for k, ds in mega.items():
#     save_dir = f"{SAVE_PATH}/{k}_choice"
#     print(f" Saving {k}-choice dataset ({len(ds):,} rows) to {save_dir}")
#     ds.save_to_disk(save_dir)
print("\n📊 Demo samples")
for k in sorted(mega.keys()):
    print(f"\n--- {k}-choice example ---")
    try:
        if len(mega[k]) > 0:
            print(mega[k][0])
        else:
            print(" (Dataset group is empty or failed concatenation)")
    except Exception as e:
        print(f" Error displaying sample for {k}-choice: {e}")
print("\n✅ Aggregation process finished.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment