@pszemraj · Created April 18, 2025
multiple‑choice dataset aggregator
#!/usr/bin/env python
"""
create_unified_mcqa.py – “batteries‑included” multiple‑choice aggregator
✅ Handles every dataset registered in REGISTRY below
✅ Survives missing/renamed columns
✅ Converts every `label` to pure int64 to avoid ClassLabel clashes
✅ Explicitly casts features to ensure concatenation compatibility
✅ Improved error handling and skipping for malformed examples
✅ Limits warning/info messages per dataset
✅ Fixes column mismatch error during cast
✅ Improved math_qa choice parsing
"""
from collections import defaultdict
from typing import Dict, Any, Callable, List, Union, Optional
import re
import sys
import traceback # For detailed error logging
import datasets
from datasets import Dataset, concatenate_datasets, Features, Value, Sequence
from tqdm import tqdm
# ─────────────────────────────────────────────────────────────────────────────
# 0) CONFIGURATION & TARGET SCHEMA
# ─────────────────────────────────────────────────────────────────────────────
MAX_WARNINGS_PER_DATASET = 5 # Limit excessive logging for common issues
TARGET_FEATURES = Features({
    'context': Value('string'),
    'question': Value('string'),
    'choices': Sequence(Value('string')),
    'label': Value('int64'),
    'source_dataset': Value('string'),  # provenance tag added to every row
})
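# For illustration (values invented, not from any real dataset), a canonical row
# under this schema looks like:
# {"context": "", "question": "What is 2 + 2?",
#  "choices": ["3", "4", "5"], "label": 1, "source_dataset": "example"}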
# Global warning counter
warning_counters = defaultdict(int)
def log_warning(dataset_name: str, message: str):
    """Logs a warning message, respecting the limit per dataset."""
    if warning_counters[dataset_name] < MAX_WARNINGS_PER_DATASET:
        print(f"Warning [{dataset_name}]: {message}", file=sys.stderr)
        warning_counters[dataset_name] += 1
    elif warning_counters[dataset_name] == MAX_WARNINGS_PER_DATASET:
        print(f"Warning [{dataset_name}]: (Further warnings suppressed for this dataset)", file=sys.stderr)
        warning_counters[dataset_name] += 1  # increment once more to prevent re-printing this message
def log_info(dataset_name: str, message: str):
    """Logs an info message (e.g. a skipped example), respecting the limit per dataset."""
    # Shares the warning counter for simplicity; info messages are not counted,
    # and there is no suppression notice, to avoid clutter.
    if warning_counters[dataset_name] < MAX_WARNINGS_PER_DATASET:
        print(f"Info [{dataset_name}]: {message}", file=sys.stderr)
# ─────────────────────────────────────────────────────────────────────────────
# 1) REGISTRY
# ─────────────────────────────────────────────────────────────────────────────
REGISTRY: Dict[str, Dict[str, Any]] = {
    "race": dict(path="race", name="all"),
    "commonsense_qa": dict(path="tau/commonsense_qa"),
    "sciq": dict(path="allenai/sciq"),
    "math_qa": dict(path="allenai/math_qa"),
    "swag": dict(path="allenai/swag", name="regular"),
    "hellaswag": dict(path="rowan/hellaswag"),
    "social_i_qa": dict(path="allenai/social_i_qa"),
    "cosmos_qa": dict(path="allenai/cosmos_qa"),
    "piqa": dict(path="ybisk/piqa"),
    "winogrande": dict(path="allenai/winogrande", name="winogrande_l"),
    "dream": dict(path="dataset-org/dream"),
    "quail": dict(path="textmachinelab/quail"),
    "medmcqa": dict(path="openlifescienceai/medmcqa"),
}
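# Each entry expands to keyword arguments for datasets.load_dataset(); e.g. the
# "race" entry is loaded as load_dataset(path="race", name="all", split="train").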
# ─────────────────────────────────────────────────────────────────────────────
# 2) SPECIAL CANONICALISERS
# Return None if an example is invalid/malformed and should be skipped.
# Now include 'source_dataset' in the returned dict.
# ─────────────────────────────────────────────────────────────────────────────
def _ensure_string_choices(choices: List[Any]) -> List[str]:
    return [str(c) if c is not None else "" for c in choices]
# --- Canonicalizers updated to accept dataset_name and return source_dataset ---
def canon_commonsenseqa(ex: Dict[str, Any], dataset_name: str) -> Optional[Dict[str, Any]]:
    try:
        # Basic structure checks
        choices_data = ex.get("choices")
        if not choices_data or not isinstance(choices_data, dict):
            log_warning(dataset_name, f"Missing or invalid 'choices' dict: {ex.get('id', 'N/A')}")
            return None
        labels = choices_data.get("label")
        texts = choices_data.get("text")
        answer_key = ex.get("answerKey")
        if not labels or not texts or answer_key is None:
            log_warning(dataset_name, f"Missing labels, texts, or answerKey in choices: {ex.get('id', 'N/A')}")
            return None
        if len(labels) != len(texts):
            log_warning(dataset_name, f"Mismatched label/text lengths: {ex.get('id', 'N/A')}")
            return None
        # Sort choices by their letter label, then locate the answer key in the
        # *sorted* label order so the index matches the reordered choices.
        order = sorted(range(len(labels)), key=lambda i: labels[i])
        choices_text = [texts[i] for i in order]
        sorted_labels = [labels[i] for i in order]
        label = sorted_labels.index(answer_key)  # 0-based index of the correct choice
    except (KeyError, ValueError, IndexError, TypeError) as e:
        log_warning(dataset_name, f"Failed to process example {ex.get('id', 'N/A')}. Error: {e}")
        return None
    return {
        "context": "",
        "question": str(ex.get("question", "")),
        "choices": _ensure_string_choices(choices_text),
        "label": int(label),
        "source_dataset": dataset_name,
    }
def canon_sciq(ex: Dict[str, Any], dataset_name: str) -> Optional[Dict[str, Any]]:
    try:
        distractors = []
        if "distractors" in ex and isinstance(ex["distractors"], list):
            distractors = ex["distractors"]
        elif "distractor1" in ex and "distractor2" in ex and "distractor3" in ex:
            # Ensure none of them is missing before adding
            d1, d2, d3 = ex.get("distractor1"), ex.get("distractor2"), ex.get("distractor3")
            if d1 is not None and d2 is not None and d3 is not None:
                distractors = [d1, d2, d3]
            else:
                log_warning(dataset_name, f"Missing one or more distractorN values: {ex.get('id', 'N/A')}")
                return None
        else:
            log_warning(dataset_name, f"Could not find distractors list or keys: {ex.get('id', 'N/A')}")
            return None
        correct_answer = ex.get("correct_answer")
        if correct_answer is None:
            log_warning(dataset_name, f"Missing 'correct_answer': {ex.get('id', 'N/A')}")
            return None
        choices = [correct_answer, *distractors]
        if len(choices) < 2:  # need at least two choices
            log_warning(dataset_name, f"Too few choices ({len(choices)}) after processing: {ex.get('id', 'N/A')}")
            return None
    except (KeyError, TypeError) as e:
        log_warning(dataset_name, f"Failed to process example {ex.get('id', 'N/A')}. Error: {e}")
        return None
    return {
        "context": str(ex.get("support", "")),
        "question": str(ex.get("question", "")),
        "choices": _ensure_string_choices(choices),
        "label": 0,  # the correct answer is always first by construction
        "source_dataset": dataset_name,
    }
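# NOTE: because SciQ rows always place the correct answer at index 0, the label
# distribution for this source is all zeros; shuffle choices downstream if
# positional balance matters for training.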
def canon_hellaswag(ex: Dict[str, Any], dataset_name: str) -> Optional[Dict[str, Any]]:
    try:
        ctx_a = str(ex.get('ctx_a', '')).strip()
        ctx_b = str(ex.get('ctx_b', '')).strip()
        context = f"{ctx_a} {ctx_b}".strip()
        # HellaSwag normally provides an 'endings' list
        choices = ex.get("endings")
        if choices is None or not isinstance(choices, list):
            # Fall back to ending0, ending1, ... if the 'endings' list is missing
            choices = []
            num_endings = 4  # default assumption
            for i in range(num_endings):
                key = f"ending{i}"
                ending = ex.get(key)
                if ending is None:
                    log_info(dataset_name, f"Skipping example {ex.get('ind', 'N/A')} due to missing key '{key}'.")
                    return None
                choices.append(ending)
        if not choices or len(choices) < 2:
            log_warning(dataset_name, f"No valid choices found for example {ex.get('ind', 'N/A')}")
            return None
        label = int(ex["label"])  # already a 0-based index (stored as a string)
        if not (0 <= label < len(choices)):
            log_warning(dataset_name, f"Invalid label {label} for {len(choices)} choices: {ex.get('ind', 'N/A')}")
            return None
    except (KeyError, ValueError, TypeError) as e:
        log_warning(dataset_name, f"Failed to process example {ex.get('ind', 'N/A')}. Error: {e}")
        return None
    return {
        "context": context,
        "question": str(ex.get("activity_label", "")),
        "choices": _ensure_string_choices(choices),
        "label": label,
        "source_dataset": dataset_name,
    }
def canon_piqa(ex: Dict[str, Any], dataset_name: str) -> Optional[Dict[str, Any]]:
    try:
        sol1 = ex.get("sol1")
        sol2 = ex.get("sol2")
        if sol1 is None or sol2 is None:
            log_warning(dataset_name, f"Missing sol1/sol2: {ex.get('id', 'N/A')}")
            return None
        choices = [sol1, sol2]
        label = int(ex["label"])  # label is 0 or 1
        if label not in (0, 1):
            log_warning(dataset_name, f"Invalid label {label}: {ex.get('id', 'N/A')}")
            return None
    except (KeyError, ValueError, TypeError) as e:
        log_warning(dataset_name, f"Failed to process example {ex.get('id', 'N/A')}. Error: {e}")
        return None
    return {
        "context": "",
        "question": str(ex.get("goal", "")),
        "choices": _ensure_string_choices(choices),
        "label": label,
        "source_dataset": dataset_name,
    }
def canon_dream(ex: Dict[str, Any], dataset_name: str) -> Optional[Dict[str, Any]]:
    try:
        context = "\n".join(ex.get("dialogue", []))
        question = str(ex.get("question", ""))
        # DREAM uses a 'choice' list rather than option_X keys
        choices = ex.get("choice")
        if choices is None or not isinstance(choices, list) or len(choices) < 2:
            log_warning(dataset_name, f"Missing or invalid 'choice' list: {ex.get('id', 'N/A')}")
            return None
        # The 'answer' field holds the text of the correct choice
        answer_text = ex.get("answer")
        if answer_text is None:
            log_warning(dataset_name, f"Missing 'answer' field: {ex.get('id', 'N/A')}")
            return None
        # Find the index of the answer text in the choices list
        # (exact string match; noisy data may need a case-insensitive comparison)
        str_choices = _ensure_string_choices(choices)
        try:
            label = str_choices.index(str(answer_text))
        except ValueError:
            log_warning(dataset_name, f"Answer text '{answer_text}' not found in choices {str_choices}: {ex.get('id', 'N/A')}")
            return None
        if not (0 <= label < len(choices)):
            # Redundant if index() succeeded, but kept as a safeguard
            log_warning(dataset_name, f"Invalid label {label} for {len(choices)} choices: {ex.get('id', 'N/A')}")
            return None
    except (KeyError, ValueError, TypeError, IndexError) as e:
        log_warning(dataset_name, f"Failed to process example {ex.get('id', 'N/A')}. Error: {e}")
        return None
    return {
        "context": context,
        "question": question,
        "choices": str_choices,
        "label": int(label),
        "source_dataset": dataset_name,
    }
def canon_quail(ex: Dict[str, Any], dataset_name: str) -> Optional[Dict[str, Any]]:
    try:
        choices = ex.get("answers")
        # QuAIL questions typically have 4 answers; accept any list with >= 2
        if choices is None or not isinstance(choices, list) or len(choices) < 2:
            log_warning(dataset_name, f"Missing or invalid 'answers' list: {ex.get('id', 'N/A')}")
            return None
        label = int(ex["correct_answer_id"])  # 0-based index
        if not (0 <= label < len(choices)):
            log_warning(dataset_name, f"Invalid label {label} for {len(choices)} choices: {ex.get('id', 'N/A')}")
            return None
    except (KeyError, ValueError, TypeError) as e:
        log_warning(dataset_name, f"Failed to process example {ex.get('id', 'N/A')}. Error: {e}")
        return None
    return {
        "context": str(ex.get("context", "")),
        "question": str(ex.get("question", "")),
        "choices": _ensure_string_choices(choices),
        "label": label,
        "source_dataset": dataset_name,
    }
def canon_medmcqa(ex: Dict[str, Any], dataset_name: str) -> Optional[Dict[str, Any]]:
    try:
        choices = []
        keys = ["opa", "opb", "opc", "opd"]
        for k in keys:
            choice = ex.get(k)
            if choice is None:
                log_info(dataset_name, f"Skipping example {ex.get('id', 'N/A')} due to missing key '{k}'.")
                return None
            choices.append(choice)
        if len(choices) < 2:
            log_warning(dataset_name, f"Too few choices found: {ex.get('id', 'N/A')}")
            return None
        cop_val = ex.get("cop")  # correct option
        label = -1
        # len check guards against '' and multi-char strings, which would pass a
        # bare substring test like `cop_val.upper() in "ABCD"`
        if isinstance(cop_val, str) and len(cop_val) == 1 and cop_val.upper() in "ABCD":
            label = ord(cop_val.upper()) - ord("A")
        elif isinstance(cop_val, int) and 0 <= cop_val < len(choices):  # 0-based index
            label = cop_val
        elif isinstance(cop_val, str) and cop_val.isdigit():  # "1", "2", etc. (assume 1-based)
            val_int = int(cop_val)
            if 1 <= val_int <= len(choices):
                label = val_int - 1
            else:
                log_warning(dataset_name, f"Numeric string 'cop' value {cop_val} out of range [1,{len(choices)}]: {ex.get('id', 'N/A')}")
                return None
        else:
            log_warning(dataset_name, f"Unexpected 'cop' value '{cop_val}' (type: {type(cop_val)}): {ex.get('id', 'N/A')}")
            return None
        if not (0 <= label < len(choices)):
            log_warning(dataset_name, f"Final label index {label} out of bounds [0,{len(choices)-1}]: {ex.get('id', 'N/A')}")
            return None
    except (KeyError, ValueError, TypeError) as e:
        log_warning(dataset_name, f"Failed to process example {ex.get('id', 'N/A')}. Error: {e}")
        return None
    return {
        "context": str(ex.get("exp", "")),  # use 'exp' (explanation) as context
        "question": str(ex.get("question", "")),
        "choices": _ensure_string_choices(choices),
        "label": int(label),
        "source_dataset": dataset_name,
    }
SPECIAL: Dict[str, Callable[[Dict[str, Any], str], Optional[Dict[str, Any]]]] = {
    "commonsense_qa": canon_commonsenseqa,
    "sciq": canon_sciq,
    "hellaswag": canon_hellaswag,
    "piqa": canon_piqa,
    "dream": canon_dream,
    "quail": canon_quail,
    "medmcqa": canon_medmcqa,
}
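# Datasets without an entry here fall through to the generic, COLMAP-driven
# canon_from_map() defined in the next section.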
# ─────────────────────────────────────────────────────────────────────────────
# 3) GENERIC CANONICALISER
# Return None if an example is invalid/malformed and should be skipped.
# Now include 'source_dataset' in the returned dict.
# ─────────────────────────────────────────────────────────────────────────────
COLMAP: Dict[str, Dict[str, Union[str, List[str]]]] = {
    "race": {
        "context": "article",
        "question": "question",
        "choices": "options",  # list of strings
        "label": "answer",     # 'A', 'B', 'C', 'D'
    },
    "math_qa": {
        "context": "Problem",
        "question": "Rationale",
        "choices": "options",  # string like "a ) text , b ) text ..."
        "label": "correct",    # 'a', 'b', 'c', 'd', 'e'
    },
    "swag": {
        "context": "sent1",
        "question": "startphrase",
        "choice_keys": ["ending0", "ending1", "ending2", "ending3"],
        "label": "label",  # 0, 1, 2, 3
    },
    "social_i_qa": {
        "context": "context",
        "question": "question",
        "choice_keys": ["answerA", "answerB", "answerC"],
        "label": "label",  # 1, 2, 3 -> converted to 0-based
    },
    "cosmos_qa": {
        "context": "context",
        "question": "question",
        "choice_keys": ["answer0", "answer1", "answer2", "answer3"],
        "label": "label",  # 0, 1, 2, 3
    },
    "winogrande": {
        "context": "",
        "question": "sentence",
        "choice_keys": ["option1", "option2"],
        "label": "answer",  # '1' or '2' -> converted to 0-based
    },
}
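# For example, the "race" mapping above makes canon_from_map() read ex["article"]
# into 'context', take ex["options"] (already a list) as 'choices', and convert
# the letter answer 'A'..'D' into a 0-based integer label.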
def _ci_get(obj: dict, key: str, default: Any = None) -> Any:
    """Case-insensitive safe lookup (an exact match is found by the same scan)."""
    if not key:
        return default
    for k in obj:
        if k.lower() == key.lower():
            return obj[k]
    return default
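# e.g. _ci_get(ex, "problem") still finds a column named "Problem", so COLMAP
# keys do not need to match the dataset's exact casing.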
def _parse_mathqa_options(options_str: str, dataset_name: str) -> Optional[List[str]]:
    """Specific parser for the MathQA 'options' string."""
    if not isinstance(options_str, str):
        log_warning(dataset_name, f"Expected string for MathQA options, got {type(options_str)}")
        return None
    # Choices are introduced by "a )", "b )", etc. Splitting on commas alone is
    # unsafe because commas can appear *within* a choice (e.g. "e ) none of these"),
    # so split on the "letter )" / "letter ." prefixes instead.
    pattern = r'\s*[a-eA-E]\s*[\)\.]\s*'
    # A single capture group keeps each delimiter in the result exactly once,
    # so we know where every choice starts.
    parts = re.split(f'({pattern})', options_str)
    choices = []
    current_choice = ""
    # parts alternates text and delimiters: ['', 'a ) ', 'choice text , ', 'b ) ', ...]
    for part in parts:
        part = part.strip()
        if not part:
            continue
        if re.fullmatch(pattern, part):
            # A new delimiter: the accumulated current_choice (if any) is complete
            if current_choice:
                choices.append(current_choice.strip().rstrip(','))  # drop trailing comma
                current_choice = ""
        else:
            # This is part of the choice text
            current_choice += part + " "  # re-insert the space the split removed
    # Add the last choice
    if current_choice:
        choices.append(current_choice.strip().rstrip(','))
    # Basic validation, with a plain comma split as a fallback if the regex fails
    if not choices or len(choices) < 2:
        choices = [c.strip() for c in options_str.split(',') if c.strip()]
        # Clean letter/number prefixes again after the comma split
        choices = [re.sub(r"^\s*([a-eA-E][\)\.]|[0-9]+\.?)\s*", "", c).strip() for c in choices]
        if not choices or len(choices) < 2:
            log_warning(dataset_name, f"Could not parse MathQA options string into multiple choices: '{options_str}'")
            return None
    return choices
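# Illustrative input/output (values invented): an options string such as
#   "a ) 38 , b ) 27.675 , c ) 30 , d ) none of these , e ) cannot be determined"
# should parse to ["38", "27.675", "30", "none of these", "cannot be determined"].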
def canon_from_map(ex: Dict[str, Any], cmap: Dict[str, Union[str, List[str]]], dataset_name: str) -> Optional[Dict[str, Any]]:
    """Canonicalises an example based on a column map. Returns None on failure."""
    try:
        out = {
            "context": str(_ci_get(ex, cmap.get("context", ""), "")),
            "question": str(_ci_get(ex, cmap.get("question", ""), "")),
        }
        # --- Choices ---
        choices = []
        if "choice_keys" in cmap:  # explicit list of keys, one per choice
            keys = cmap["choice_keys"]
            if not isinstance(keys, list):
                log_warning(dataset_name, "'choice_keys' must be a list in COLMAP. Skipping.")
                return None
            for key in keys:
                choice = _ci_get(ex, key)
                if choice is None:
                    log_info(dataset_name, f"Skipping example due to missing choice key '{key}'.")
                    return None
                choices.append(choice)
        elif "choices" in cmap:  # single key containing all choices
            choices_key = cmap["choices"]
            raw_choices = _ci_get(ex, choices_key)
            if raw_choices is None:
                log_info(dataset_name, f"Skipping example due to missing choices key '{choices_key}'.")
                return None
            if isinstance(raw_choices, list):  # already a list (e.g. race)
                choices = raw_choices
            elif isinstance(raw_choices, str):  # string that needs splitting
                if dataset_name == "math_qa":
                    parsed = _parse_mathqa_options(raw_choices, dataset_name)
                    if parsed is None:
                        return None  # skip if parsing failed
                    choices = parsed
                else:  # generic string splitting on '|', ';', or newlines
                    split_delimiters = r'\s*\|\s*|\s*;\s*|\s*\n\s*'
                    parts = re.split(split_delimiters, raw_choices.strip())
                    # Basic prefix cleaning for the generic case
                    choices = [re.sub(r"^\s*([a-zA-Z0-9][\)\.]|[0-9]+\.?)\s*", "", p).strip() for p in parts if p.strip()]
            else:
                log_warning(dataset_name, f"Unexpected type for choices field '{choices_key}': {type(raw_choices)}. Skipping.")
                return None
        else:
            log_warning(dataset_name, "Column map must define 'choice_keys' or 'choices'. Skipping.")
            return None
        if not choices or len(choices) < 2:
            log_warning(dataset_name, "Fewer than 2 valid choices extracted. Skipping.")
            return None
        out["choices"] = _ensure_string_choices(choices)
        # --- Label ---
        label_key = cmap.get("label")
        if not label_key:
            log_warning(dataset_name, "Column map missing 'label' key. Skipping.")
            return None
        label_val = _ci_get(ex, label_key)
        if label_val is None:
            log_info(dataset_name, f"Skipping example due to missing label key '{label_key}'.")
            return None
        label_int = -1  # default invalid label
        is_one_based_known = dataset_name in ["social_i_qa", "winogrande"]
        if isinstance(label_val, str):
            label_str = label_val.strip()
            if label_str.isdigit():
                val_int = int(label_str)
                if is_one_based_known and val_int >= 1:
                    label_int = val_int - 1
                else:  # otherwise assume a 0-based numeric string
                    label_int = val_int
            elif len(label_str) == 1 and 'A' <= label_str.upper() <= 'Z':
                # 'a'/'A' -> 0, 'b'/'B' -> 1, etc.
                label_int = ord(label_str.upper()) - ord('A')
            else:
                log_warning(dataset_name, f"Unhandled string label format: '{label_val}'. Skipping.")
                return None
        elif isinstance(label_val, (int, float)):  # already numeric
            val_int = int(label_val)
            if is_one_based_known and val_int >= 1:
                label_int = val_int - 1
            else:  # otherwise assume 0-based
                label_int = val_int
        else:
            log_warning(dataset_name, f"Unexpected label type: {type(label_val)} for value '{label_val}'. Skipping.")
            return None
        # Validate that the label index is within bounds
        if not (0 <= label_int < len(out["choices"])):
            log_warning(dataset_name, f"Label index {label_int} out of bounds for choices length {len(out['choices'])}. Original label: '{label_val}'. Skipping.")
            return None
        out["label"] = label_int
        out["source_dataset"] = dataset_name
        return out
    except Exception as e:
        log_warning(dataset_name, f"Unhandled exception processing example. Error: {e}. Data: {ex}")
        # traceback.print_exc(file=sys.stderr)  # uncomment for a full trace while debugging
        return None  # skip on any unexpected error
# ─────────────────────────────────────────────────────────────────────────────
# 4) MAIN PROCESSING LOOP
# ─────────────────────────────────────────────────────────────────────────────
aggregated: Dict[int, List[Dataset]] = defaultdict(list)
skipped_datasets: Dict[str, str] = {}
processed_counts: Dict[str, int] = defaultdict(int)
skipped_examples: Dict[str, int] = defaultdict(int)
print("\n▶ loading + canonicalising\n")
warning_counters.clear() # Reset counters for the run
for name, load_args in tqdm(REGISTRY.items(), desc="Datasets"):
    print(f"\nProcessing: {name}")
    warning_counters[name] = 0  # reset counter for this dataset
    original_count = 0
    processed_ds = None
    try:
        ds = datasets.load_dataset(**load_args, split="train", trust_remote_code=True)
        original_count = len(ds)
        original_columns = ds.column_names  # capture original columns *before* mapping
        print(f" Loaded {name} ({original_count:,} examples)")
    except Exception as e:
        print(f" Skipping {name} - load_error: {e}", file=sys.stderr)
        skipped_datasets[name] = f"load_error: {e}"
        continue
    try:
        map_function = None
        if name in SPECIAL:
            print(f" Applying special canonicaliser for {name}...")
            map_function = lambda ex, name=name: SPECIAL[name](ex, name)  # bind name at definition time
        elif name in COLMAP:
            print(f" Applying generic canonicaliser for {name}...")
            map_function = lambda ex, name=name: canon_from_map(ex, COLMAP[name], name)
        else:
            print(f" Skipping {name} - no canonicaliser found (add to SPECIAL or COLMAP)", file=sys.stderr)
            skipped_datasets[name] = "no_canonicaliser"
            continue
        # --- Perform mapping ---
        # remove_columns ensures only the dict returned by map_function remains
        processed_ds = ds.map(
            map_function,
            remove_columns=original_columns,  # remove all original columns
            desc=f"canon {name}",
            load_from_cache_file=False,  # force re-processing
            batched=False,  # process example by example so single rows can be skipped
        )
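        # NOTE (assumption): some versions of `datasets` require map functions to
        # return a dict and do not propagate None as a row value; if map() raises
        # here, a more robust pattern is to return a sentinel row (e.g. label=-1)
        # from the canonicalisers and filter on that instead of `is not None`.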
        # --- Filter out None examples (skipped by canonicalisers) ---
        initial_processed_count = len(processed_ds)
        processed_ds = processed_ds.filter(lambda ex: ex is not None)
        final_processed_count = len(processed_ds)
        num_skipped = initial_processed_count - final_processed_count
        if num_skipped > 0:
            print(f" Filtered out {num_skipped:,} invalid/skipped examples for {name}.")
            skipped_examples[name] = num_skipped
        # Check for empty datasets after filtering
        if final_processed_count == 0:
            print(f" Skipping {name} - resulted in empty dataset after processing and filtering.", file=sys.stderr)
            skipped_datasets[name] = "empty_after_processing"
            continue
        # --- Cast features explicitly ---
        # The dataset should now contain *only* the columns returned by the
        # canonicaliser, which must match TARGET_FEATURES.
        print(f" Casting features for {name}...")
        processed_ds = processed_ds.cast(TARGET_FEATURES)
        print(f" Cast successful for {name}.")
        # Determine the number of choices. After the cast, checking the first
        # example is the simplest reliable approach for variable-length sequences.
        n_choices = -1
        if isinstance(processed_ds.features['choices'], Sequence):
            n_choices = len(processed_ds[0]["choices"])
        else:
            # Should not happen after a successful cast
            print(f" Skipping {name} - 'choices' feature is not a Sequence after casting.", file=sys.stderr)
            skipped_datasets[name] = "choices_not_sequence"
            continue
        if n_choices <= 1:
            print(f" Skipping {name} - only {n_choices} choice(s) found after processing. Requires >= 2.", file=sys.stderr)
            skipped_datasets[name] = f"too_few_choices ({n_choices})"
            continue
        aggregated[n_choices].append(processed_ds)
        processed_counts[name] = final_processed_count
        print(f"✓ {name:<14} → {n_choices}-choice, {final_processed_count:,} rows added (out of {original_count:,} original).")
    except Exception as e:
        print(f" Skipping {name} - unhandled map/filter/cast error: {e}", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)  # print full traceback
        skipped_datasets[name] = f"map_filter_cast_error: {e}"
        continue
# ─────────────────────────────────────────────────────────────────────────────
# 5) CONCATENATE + REPORT
# ─────────────────────────────────────────────────────────────────────────────
print("\n▶ concatenating datasets by choice count\n")
mega: Dict[int, Dataset] = {}
# Reset counters before the final report
warning_counters.clear()
for k, v_list in aggregated.items():
    if v_list:
        print(f" Concatenating {len(v_list)} dataset(s) for {k}-choice group...")
        try:
            # Features should already be identical across the group thanks to
            # the explicit cast to TARGET_FEATURES above.
            mega[k] = concatenate_datasets(v_list)
            print(f" Concatenated {k}-choice group ({len(mega[k]):,} rows) successfully.")
        except Exception as e:
            print(f" Failed to concatenate {k}-choice group: {e}", file=sys.stderr)
            skipped_datasets[f"concat_{k}_choice"] = f"concat_error: {e}"
            traceback.print_exc(file=sys.stderr)
    else:
        print(f" Skipping empty group for {k}-choice.")
print("\n★ SUMMARY ★")
total_rows = 0
print("\n--- Aggregated Datasets ---")
for k in sorted(mega.keys()):
ds = mega[k]
count = len(ds)
# Get unique sources contributing to this group
sources = set()
if count > 0:
try:
# Sample sources - might be slow for large datasets
sample_size = min(1000, count)
sources = set(ds.select(range(sample_size))['source_dataset'])
# If needed, get all sources (can be slow):
# sources = set(ds['source_dataset'])
except Exception as e:
print(f" (Could not retrieve sources for {k}-choice group: {e})")
num_datasets = len(sources) if sources else len(aggregated.get(k,[])) # Approx if sampling fails
print(f" {k}-choice : {count:>10,} rows (from ~{num_datasets} dataset(s))")
total_rows += count
print(f" ──────────────────────────")
print(f" Total : {total_rows:>10,} rows")
print("\n--- Processed Examples per Source ---")
for name in sorted(REGISTRY.keys()):
if name in processed_counts:
original = REGISTRY.get(name, {}).get('_original_count', 0) # Get original count if stored
skipped = skipped_examples.get(name, 0)
added = processed_counts[name]
print(f" {name:<20}: {added:>10,} rows added ({skipped:,} skipped)")
elif name not in skipped_datasets:
print(f" {name:<20}: {'0 rows (unexpected)':>10}")
if skipped_examples:
print("\n--- Skipped Examples Summary ---")
total_skipped = sum(skipped_examples.values())
print(f" Total examples skipped across all datasets: {total_skipped:,}")
if skipped_datasets:
print("\n⚠ SKIPPED DATASETS / FATAL ERRORS ⚠")
for n, r in sorted(skipped_datasets.items()):
error_msg = r.split('\n')[0]
if len(error_msg) > 100: error_msg = error_msg[:97] + "..."
print(f" {n:<25}: {error_msg}")
# Optional: save the aggregated datasets
# SAVE_PATH = "./aggregated_mcqa_v7"
# print(f"\n💾 Saving aggregated datasets to {SAVE_PATH}...")
# for k, ds in mega.items():
#     save_dir = f"{SAVE_PATH}/{k}_choice"
#     print(f" Saving {k}-choice dataset ({len(ds):,} rows) to {save_dir}")
#     ds.save_to_disk(save_dir)
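# To reload a saved group later (illustrative; assumes the save block above ran):
# from datasets import load_from_disk
# four_choice = load_from_disk(f"{SAVE_PATH}/4_choice")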
print("\n📊 Demo samples")
for k in sorted(mega.keys()):
print(f"\n--- {k}-choice example ---")
try:
if len(mega[k]) > 0:
print(mega[k][0])
else:
print(" (Dataset group is empty or failed concatenation)")
except Exception as e:
print(f" Error displaying sample for {k}-choice: {e}")
print("\n✅ Aggregation process finished.")