Load Zyda-2 with streaming: interleave its components by weight, optionally mixing in one additional dataset.
from typing import Dict, List, Optional

import datasets

# Optional: keep the version print outside the function if desired
# print(f"Using datasets library version: {datasets.__version__}")


def create_interleaved_streaming_dataset(
    dataset_path: str = "Zyphra/Zyda-2",
    component_configs: Optional[Dict[str, str]] = None,
    component_order: Optional[List[str]] = None,
    interleave_weights: Optional[List[float]] = None,
    common_columns: List[str] = ["nemo_id", "text"],
    split: str = "train",
    stopping_strategy: str = "all_exhausted",
    verbose: bool = True,
) -> datasets.IterableDataset:
    """
    Loads multiple dataset components in streaming mode, selects common columns,
    and interleaves them according to specified weights.

    Args:
        dataset_path: The Hugging Face Hub path for the dataset repository
            (e.g., "Zyphra/Zyda-2").
        component_configs: A dictionary mapping a short identifier (key) to the
            specific dataset configuration name (value). Defaults to the
            standard Zyda-2 components if None.
        component_order: A list of the short identifiers (keys from
            component_configs) specifying the exact order for loading and for
            applying interleave_weights. Must match the order of the weights.
            Defaults to the standard Zyda-2 order if None.
        interleave_weights: A list of sampling probabilities for interleaving.
            Must sum to (approximately) 1.0 and match the order given in
            component_order. Defaults to the standard Zyda-2 weights if None.
        common_columns: A list of column names to keep from each component dataset.
        split: The dataset split to load (e.g., "train").
        stopping_strategy: The strategy for stopping interleaving
            ("all_exhausted" or "first_exhausted").
        verbose: If True, print progress messages.

    Returns:
        An interleaved datasets.IterableDataset stream.

    Raises:
        ValueError: If component_order and interleave_weights have different lengths.
        KeyError: If a key in component_order is not found in component_configs.
    """
    # --- Default configuration (if not provided) ---
    if component_configs is None:
        component_configs = {
            "dclm": "dclm_crossdeduped",
            "zyda": "zyda_crossdeduped-filtered",
            "dolma": "dolma-cc_crossdeduped-filtered",
            "fwe": "fwe3",
        }
    if component_order is None:
        # This order MUST match the default interleave_weights
        component_order = ["dclm", "zyda", "dolma", "fwe"]
    if interleave_weights is None:
        # Raw component weights (DCLM: 4.0, Zyda: 0.16, Dolma-CC: 0.24, FWE3: 4.0),
        # normalized to per-example sampling probabilities
        interleave_weights = [0.4038, 0.0316, 0.0585, 0.5061]  # sums to ~1.0

    # --- Validation ---
    if len(component_order) != len(interleave_weights):
        raise ValueError(
            f"Mismatch between component_order (len={len(component_order)}) "
            f"and interleave_weights (len={len(interleave_weights)})."
        )
    if not all(key in component_configs for key in component_order):
        missing_keys = [key for key in component_order if key not in component_configs]
        raise KeyError(
            f"Keys from component_order not found in component_configs: {missing_keys}"
        )

    # --- Load and prepare components ---
    if verbose:
        print(f"Loading {len(component_order)} components with streaming enabled...")
    streamed_components = []
    for component_key in component_order:
        config_name = component_configs[component_key]
        if verbose:
            print(f"  Loading component '{component_key}' (config: '{config_name}')...")
        # 1. Load the individual component with streaming
        ds = datasets.load_dataset(
            dataset_path, name=config_name, split=split, streaming=True
        )
        # 2. Select only the common columns
        try:
            # Prefer select_columns (more explicit)
            prepared_ds = ds.select_columns(common_columns)
        except (AttributeError, TypeError, ValueError):
            # Fallback for `datasets` versions where select_columns is unavailable
            if verbose:
                print(
                    f"  Warning: select_columns failed for component '{component_key}'. "
                    "Attempting remove_columns."
                )
            # column_names can be None for streaming datasets with unknown features
            existing_columns = ds.column_names or []
            columns_to_remove = [
                col for col in existing_columns if col not in common_columns
            ]
            if columns_to_remove:
                prepared_ds = ds.remove_columns(columns_to_remove)
            else:
                # Columns are unknown or already reduced to the common set
                prepared_ds = ds
        streamed_components.append(prepared_ds)

    if verbose:
        print(f"\nSelected common columns: {common_columns}")
        print("Interleaving datasets with specified probabilities...")
        print(f"  Order: {component_order}")
        print(f"  Weights: {interleave_weights}")

    # --- Interleave datasets ---
    final_streamed_ds = datasets.interleave_datasets(
        streamed_components,
        probabilities=interleave_weights,
        stopping_strategy=stopping_strategy,
    )
    if verbose:
        print("\nSuccessfully created the streamed interleaved dataset.")
    return final_streamed_ds


# --- Example usage ---
if __name__ == "__main__":
    print(f"Using datasets library version: {datasets.__version__}")

    # Create the dataset using the function with default parameters
    streamed_ds = create_interleaved_streaming_dataset(verbose=True)

    print("\n--- Example Usage ---")
    print("Taking one example entry from the final stream:")
    try:
        first_entry = next(iter(streamed_ds))
        print(first_entry)
    except StopIteration:
        print("The dataset stream is empty.")
    except Exception as e:
        print(f"An error occurred while fetching the first entry: {e}")

    # Example of how you might iterate (conceptual):
    # from itertools import islice
    # for i, entry in enumerate(islice(streamed_ds, 5), start=1):
    #     preview = {
    #         k: (str(v)[:50] + "...") if isinstance(v, str) and len(v) > 50 else v
    #         for k, v in entry.items()
    #     }
    #     print(f"Entry {i}: {preview}")  # print truncated text

    # You can customize parameters:
    # custom_stream = create_interleaved_streaming_dataset(
    #     component_order=["dclm", "fwe"],  # only use DCLM and FWE
    #     interleave_weights=[0.5, 0.5],    # equal weighting
    #     common_columns=["text"],          # only keep text
    #     verbose=False,                    # less output
    # )
    # print("\nCreated custom stream (first entry):")
    # print(next(iter(custom_stream)))
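If you plan to train on the stream, a buffered shuffle is usually worth adding, since interleaving alone preserves each component's internal order. A minimal sketch using IterableDataset.shuffle (the buffer_size and seed values are illustrative, not tuned):

from itertools import islice

# Approximate shuffling via a fixed-size buffer; larger buffers mix better
shuffled_ds = streamed_ds.shuffle(seed=42, buffer_size=10_000)
for entry in islice(shuffled_ds, 3):
    print(str(entry.get("text", ""))[:80])  # preview the first 80 chars of each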
A second variant of the loader follows: the same Zyda-2 streaming mix, with support for interleaving one optional additional dataset.
import math
from typing import Any, Dict, List, Optional

import datasets


def load_zyda2_with_optional_dataset(
    additional_dataset_config: Optional[Dict[str, Any]] = None,
    required_columns: List[str] = ["nemo_id", "text"],
) -> datasets.IterableDataset:
    """
    Loads the standard Zyda-2 dataset components using streaming and interleaving.
    Optionally loads and adds one additional dataset to the mix.

    The standard Zyda-2 components (dclm, zyda, dolma, fwe) are always loaded.

    Args:
        additional_dataset_config: An optional dictionary defining the additional
            dataset. If provided, it must contain:
                - 'hf_name': Hugging Face dataset name.
                - 'config_name': Dataset configuration name.
                - 'weight': Positive relative weight for this dataset.
            It may also contain:
                - 'split': Dataset split (defaults to 'train').
                - 'id': A unique identifier string (for logging).
        required_columns: A list of column names that MUST be present in the final
            output dataset. All loaded datasets (including the optional one)
            must contain these columns. Defaults to ['nemo_id', 'text'].

    Returns:
        An IterableDataset containing the interleaved data streams.

    Raises:
        ValueError: If additional_dataset_config is invalid, weights are non-positive,
            or if any dataset (including the optional one) lacks the required_columns.
        Exception: Re-raises exceptions from datasets.load_dataset.
    """
    if not required_columns:
        raise ValueError("required_columns list cannot be empty.")

    # --- Standard Zyda-2 configuration ---
    # Weights are pre-normalized per-example sampling probabilities.
    zyda2_base_configs = [
        {
            "id": "dclm",
            "hf_name": "Zyphra/Zyda-2",
            "config_name": "dclm_crossdeduped",
            "split": "train",
            "weight": 0.4038,
        },
        {
            "id": "zyda",
            "hf_name": "Zyphra/Zyda-2",
            "config_name": "zyda_crossdeduped-filtered",
            "split": "train",
            "weight": 0.0316,
        },
        {
            "id": "dolma",
            "hf_name": "Zyphra/Zyda-2",
            "config_name": "dolma-cc_crossdeduped-filtered",
            "split": "train",
            "weight": 0.0585,
        },
        {
            "id": "fwe",
            "hf_name": "Zyphra/Zyda-2",
            "config_name": "fwe3",
            "split": "train",
            "weight": 0.5061,
        },
    ]
    # Verify the base weights sum to 1 (they are probabilities here)
    assert math.isclose(
        sum(c["weight"] for c in zyda2_base_configs), 1.0
    ), "Internal Zyda-2 weights don't sum to 1."

    all_configs_to_load = list(zyda2_base_configs)  # start with base Zyda-2
    final_dataset_ids = [c["id"] for c in zyda2_base_configs]
    final_weights = [
        c["weight"] for c in zyda2_base_configs
    ]  # initially probabilities

    # --- Handle the optional additional dataset ---
    if additional_dataset_config:
        print("\n--- Processing Additional Dataset ---")
        # Validate the config
        req_keys = ["hf_name", "config_name", "weight"]
        if not all(key in additional_dataset_config for key in req_keys):
            raise ValueError(
                f"additional_dataset_config missing required keys: {req_keys}. "
                f"Got: {additional_dataset_config}"
            )
        add_weight = additional_dataset_config["weight"]
        if not isinstance(add_weight, (int, float)) or add_weight <= 0:
            raise ValueError(
                f"Weight for additional dataset must be positive. Got: {add_weight}"
            )
        # Add the config to the list
        add_config = {
            "id": additional_dataset_config.get(
                "id", additional_dataset_config["hf_name"]
            ),  # fall back to hf_name if id is missing
            "hf_name": additional_dataset_config["hf_name"],
            "config_name": additional_dataset_config["config_name"],
            "split": additional_dataset_config.get("split", "train"),
            "weight": add_weight,  # store the relative weight
        }
        all_configs_to_load.append(add_config)
        final_dataset_ids.append(add_config["id"])

        # Adjust weights: treat the original Zyda-2 probabilities as relative
        # weights summing to 1.0, and the new weight as relative to that sum.
        # If the new weight is w, the total relative weight is 1.0 + w, and the
        # final probabilities become [original_prob / (1.0 + w), ..., w / (1.0 + w)].
        zyda_relative_sum = 1.0  # base Zyda-2 probabilities sum to 1
        new_relative_weight = add_config["weight"]
        total_relative_weight = zyda_relative_sum + new_relative_weight
        final_weights.append(new_relative_weight)
        print(
            f"Added dataset '{add_config['id']}' with relative weight {new_relative_weight}."
        )
        print(f"Total relative weight: {total_relative_weight}")
    else:
        print("\n--- Loading Standard Zyda-2 Components Only ---")
        total_relative_weight = sum(final_weights)  # should be 1.0

    # --- Load, verify, and prepare all datasets ---
    loaded_datasets = []
    print("\n--- Loading and Preparing Datasets ---")
    for config in all_configs_to_load:
        dataset_id = config["id"]
        hf_name = config["hf_name"]
        config_name = config["config_name"]
        split = config["split"]
        print(f"Loading '{dataset_id}' ({hf_name} / {config_name} / {split})...")
        try:
            # Load the dataset in streaming mode
            ds = datasets.load_dataset(
                hf_name, name=config_name, split=split, streaming=True
            )

            # Verify and select columns
            print(
                f"  Verifying and selecting columns {required_columns} for '{dataset_id}'..."
            )
            try:
                existing_columns = list(ds.features.keys())
            except Exception as e:
                # Streaming datasets may not expose features up front
                print(
                    f"  Warning: Could not reliably get features for '{dataset_id}'. Error: {e}"
                )
                existing_columns = list(required_columns)  # optimistic assumption

            missing_columns = [
                col for col in required_columns if col not in existing_columns
            ]
            if missing_columns:
                raise ValueError(
                    f"Dataset '{dataset_id}' ({hf_name}/{config_name}) is missing "
                    f"required columns: {missing_columns}. "
                    f"Available columns: {existing_columns}"
                )

            # Keep only the required columns
            cols_to_remove = [
                col for col in existing_columns if col not in required_columns
            ]
            if cols_to_remove:
                print(f"  Removing columns: {cols_to_remove}")
                processed_ds = ds.remove_columns(cols_to_remove)
            else:
                processed_ds = ds
            loaded_datasets.append(processed_ds)

            final_cols = (
                list(processed_ds.features.keys())
                if processed_ds.features
                else required_columns
            )
            print(f"  Successfully prepared '{dataset_id}'. Kept columns: {final_cols}")
        except Exception as e:
            print(
                f"Error loading or processing dataset '{dataset_id}' "
                f"({hf_name}/{config_name}): {e}"
            )
            raise

    # --- Normalize final weights to probabilities ---
    if total_relative_weight <= 0:
        raise ValueError("Total relative weight must be positive.")
    probabilities = [w / total_relative_weight for w in final_weights]
    if not math.isclose(sum(probabilities), 1.0):
        print(
            f"Warning: Final probabilities sum to {sum(probabilities):.4f}, not 1.0. Check weights."
        )

    # --- Interleave datasets ---
    print("\n--- Interleaving Datasets ---")
    print(f"Dataset order for interleaving: {final_dataset_ids}")
    print(f"Final probabilities: {[f'{p:.4f}' for p in probabilities]}")
    interleaved_ds = datasets.interleave_datasets(
        loaded_datasets, probabilities=probabilities, stopping_strategy="all_exhausted"
    )
    print("\nSuccessfully created interleaved streaming dataset.")
    return interleaved_ds


# --- Example usage ---
if __name__ == "__main__":
    # --- Example 1: load only standard Zyda-2 ---
    print("\n--- EXAMPLE 1: Loading Standard Zyda-2 ---")
    try:
        # Specify the columns needed from the standard Zyda-2 mix
        std_columns = ["nemo_id", "text"]
        zyda2_standard_stream = load_zyda2_with_optional_dataset(
            required_columns=std_columns
        )
        print("\nStandard Zyda-2 stream created. Fetching first item:")
        first_item_std = next(iter(zyda2_standard_stream))
        print(first_item_std)
    except Exception as e:
        print(f"\nError in Example 1: {e}")

    # --- Example 2: load Zyda-2 + smollm-corpus-python ---
    print("\n\n--- EXAMPLE 2: Loading Zyda-2 + SmolLM Python ---")
    try:
        # Define the additional dataset configuration
        smollm_config = {
            "id": "smollm_py",
            "hf_name": "BEE-spoke-data/smollm-corpus-python",
            "config_name": "default",
            "split": "train",
            "weight": 0.2,  # relative weight vs. Zyda-2's total weight of 1.0
        }
        # smollm shares only 'text' with Zyda-2's ['nemo_id', 'text'], so 'text'
        # MUST be the sole required column for the combined dataset.
        combined_columns = ["text"]
        zyda2_plus_smollm_stream = load_zyda2_with_optional_dataset(
            additional_dataset_config=smollm_config, required_columns=combined_columns
        )
        print("\nZyda-2 + SmolLM stream created. Fetching first item:")
        first_item_combined = next(iter(zyda2_plus_smollm_stream))
        print(first_item_combined)
    except Exception as e:
        print(f"\nError in Example 2: {e}")