@pszemraj
Last active April 27, 2025 21:25
Load Zyda-2 with streaming
from typing import Dict, List, Optional
import datasets
# Optional: Keep the version print outside the function if desired
# print(f"Using datasets library version: {datasets.__version__}")
def create_interleaved_streaming_dataset(
dataset_path: str = "Zyphra/Zyda-2",
component_configs: Optional[Dict[str, str]] = None,
component_order: Optional[List[str]] = None,
interleave_weights: Optional[List[float]] = None,
common_columns: List[str] = ["nemo_id", "text"],
split: str = "train",
stopping_strategy: str = "all_exhausted",
verbose: bool = True,
) -> datasets.IterableDataset:
"""
Loads multiple dataset components in streaming mode, selects common columns,
and interleaves them according to specified weights.
Args:
dataset_path: The Hugging Face Hub path for the dataset repository
(e.g., "Zyphra/Zyda-2").
component_configs: A dictionary mapping a short identifier (key) to the
specific dataset configuration 'name' (value).
Defaults to Zyda-2 standard components if None.
component_order: A list of the short identifiers (keys from
component_configs) specifying the exact order for loading
and applying interleave_weights. Must match the order
of weights. Defaults to Zyda-2 standard order if None.
interleave_weights: A list of probabilities for interleaving. Must sum
close to 1.0 and match the order specified in
`component_order`. Defaults to Zyda-2 standard weights
if None.
common_columns: A list of column names to keep from each component dataset.
split: The dataset split to load (e.g., "train").
        stopping_strategy: The strategy for stopping interleaving:
                           "all_exhausted" (components that run out of samples
                           are oversampled until every component has been seen
                           at least once) or "first_exhausted" (stop as soon as
                           any component runs out).
verbose: If True, print progress messages.
Returns:
An interleaved datasets.IterableDataset stream.
Raises:
ValueError: If component_order and interleave_weights have different lengths
or if component_configs/component_order/interleave_weights
are not provided together when defaults are not used.
KeyError: If a key in component_order is not found in component_configs.
"""
# --- Default Configuration (if not provided) ---
if component_configs is None:
component_configs = {
"dclm": "dclm_crossdeduped",
"zyda": "zyda_crossdeduped-filtered",
"dolma": "dolma-cc_crossdeduped-filtered",
"fwe": "fwe3",
}
if component_order is None:
# This order MUST match the default interleave_weights
component_order = ["dclm", "zyda", "dolma", "fwe"]
if interleave_weights is None:
# Default weights corresponding to the default component_order
# DCLM: 4.0, Zyda: 0.16, Dolma-CC: 0.24, FWE3: 4.0 -> Normalized
interleave_weights = [0.4038, 0.0316, 0.0585, 0.5061] # Sum ~ 1.0
# --- Validation ---
if len(component_order) != len(interleave_weights):
raise ValueError(
f"Mismatch between component_order (len={len(component_order)}) "
f"and interleave_weights (len={len(interleave_weights)})."
)
if not all(key in component_configs for key in component_order):
missing_keys = [key for key in component_order if key not in component_configs]
raise KeyError(
f"Keys from component_order not found in component_configs: {missing_keys}"
)
# --- Load and Prepare Components ---
if verbose:
print(f"Loading {len(component_order)} components with streaming enabled...")
streamed_components = []
for i, component_key in enumerate(component_order):
config_name = component_configs[component_key]
if verbose:
print(f" Loading component '{component_key}' (config: '{config_name}')...")
# 1. Load individual component with streaming
ds = datasets.load_dataset(
dataset_path, name=config_name, split=split, streaming=True
)
# 2. Select only the common columns
try:
# Attempt select_columns first (more explicit)
prepared_ds = ds.select_columns(common_columns)
except (AttributeError, TypeError, ValueError):
# Fallback or alternative if select_columns is problematic
if verbose:
print(
f" Warning: select_columns failed for component '{component_key}'. "
"Attempting remove_columns."
)
            # ds.column_names can be None for a streaming dataset whose features
            # are not known ahead of time; guard against that case.
            columns_to_remove = [
                col for col in (ds.column_names or []) if col not in common_columns
            ]
if columns_to_remove:
prepared_ds = ds.remove_columns(columns_to_remove)
else:
# If all columns are already the common ones, no change needed
prepared_ds = ds
streamed_components.append(prepared_ds)
if verbose:
        print(f"\nSelected common columns: {common_columns}")
print("Interleaving datasets with specified probabilities...")
print(f" Order: {component_order}")
print(f" Weights: {interleave_weights}")
# --- Interleave Datasets ---
final_streamed_ds = datasets.interleave_datasets(
streamed_components,
probabilities=interleave_weights,
stopping_strategy=stopping_strategy,
)
if verbose:
print("\nSuccessfully created the streamed interleaved dataset.")
return final_streamed_ds
# --- Example Usage ---
if __name__ == "__main__":
print(f"Using datasets library version: {datasets.__version__}")
# Create the dataset using the function with default parameters
streamed_ds = create_interleaved_streaming_dataset(verbose=True)
print("\n--- Example Usage ---")
print("Taking one example entry from the final stream:")
# Get the first item from the iterator
try:
first_entry = next(iter(streamed_ds))
print(first_entry)
except StopIteration:
print("The dataset stream is empty.")
except Exception as e:
print(f"An error occurred while fetching the first entry: {e}")
# Example of how you might use it (conceptual)
# print("\nSimulating iteration (first 5 entries):")
# count = 0
# for entry in streamed_ds:
# print(f"Entry {count+1}: { {k: str(v)[:50] + '...' if isinstance(v, str) and len(v) > 50 else v for k,v in entry.items()} }") # Print truncated text
# count += 1
# if count >= 5:
# break
# You can customize parameters:
# custom_stream = create_interleaved_streaming_dataset(
# component_order=["dclm", "fwe"], # Only use DCLM and FWE
# interleave_weights=[0.5, 0.5], # Equal weighting (adjust component_configs if needed)
# common_columns=["text"], # Only keep text
# verbose=False # Less output
# )
# print("\nCreated custom stream (first entry):")
# print(next(iter(custom_stream)))
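    # Optional: a minimal sketch of approximate shuffling for the stream before
    # training. IterableDataset.shuffle fills a buffer and samples from it; the
    # seed and buffer_size values below are illustrative assumptions, not tuned
    # recommendations.
    # import itertools
    # shuffled_ds = streamed_ds.shuffle(seed=42, buffer_size=10_000)
    # for entry in itertools.islice(shuffled_ds, 3):
    #     print({k: str(v)[:80] for k, v in entry.items()})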
# --- Variant: load Zyda-2 with an optional additional dataset mixed in ---
import math
from typing import Any, Dict, List, Optional
import datasets
def load_zyda2_with_optional_dataset(
additional_dataset_config: Optional[Dict[str, Any]] = None,
required_columns: List[str] = ["nemo_id", "text"],
) -> datasets.IterableDataset:
"""
Loads the standard Zyda-2 dataset components using streaming and interleaving.
Optionally loads and adds one additional dataset to the mix.
The standard Zyda-2 components (dclm, zyda, dolma, fwe) are always loaded.
Args:
        additional_dataset_config: An optional dictionary defining the additional
                                   dataset. If provided, it must contain:
                                   - 'hf_name': Hugging Face dataset name.
                                   - 'config_name': Dataset configuration name.
                                   - 'weight': Positive relative weight for this dataset.
                                   It may also contain:
                                   - 'split': Dataset split (defaults to 'train').
                                   - 'id': A unique identifier string for logging
                                     (defaults to 'hf_name').
required_columns: A list of column names that MUST be present in the final
output dataset. All loaded datasets (including the optional
one) must contain these columns. Defaults to ['nemo_id', 'text'].
Returns:
An IterableDataset containing the interleaved data streams.
Raises:
ValueError: If additional_dataset_config is invalid, weights are non-positive,
or if any dataset (including optional) lacks the required_columns.
ImportError: If 'datasets' library isn't installed.
Exception: Can re-raise exceptions from datasets.load_dataset.
"""
if not required_columns:
raise ValueError("required_columns list cannot be empty.")
# --- Standard Zyda-2 Configuration ---
zyda2_base_configs = [
{
"id": "dclm",
"hf_name": "Zyphra/Zyda-2",
"config_name": "dclm_crossdeduped",
"split": "train",
"weight": 0.4038,
}, # Using pre-normalized doc weights directly
{
"id": "zyda",
"hf_name": "Zyphra/Zyda-2",
"config_name": "zyda_crossdeduped-filtered",
"split": "train",
"weight": 0.0316,
},
{
"id": "dolma",
"hf_name": "Zyphra/Zyda-2",
"config_name": "dolma-cc_crossdeduped-filtered",
"split": "train",
"weight": 0.0585,
},
{
"id": "fwe",
"hf_name": "Zyphra/Zyda-2",
"config_name": "fwe3",
"split": "train",
"weight": 0.5061,
},
]
# Verify base weights sum to 1 initially (they are probabilities here)
assert math.isclose(
sum(c["weight"] for c in zyda2_base_configs), 1.0
), "Internal Zyda-2 weights don't sum to 1."
all_configs_to_load = list(zyda2_base_configs) # Start with base Zyda-2
final_dataset_ids = [c["id"] for c in zyda2_base_configs]
final_weights = [
c["weight"] for c in zyda2_base_configs
] # These are initially probabilities
# --- Handle Optional Additional Dataset ---
if additional_dataset_config:
print("\n--- Processing Additional Dataset ---")
# Validate config
req_keys = ["hf_name", "config_name", "weight"]
if not all(key in additional_dataset_config for key in req_keys):
raise ValueError(
f"additional_dataset_config missing required keys: {req_keys}. Got: {additional_dataset_config}"
)
add_weight = additional_dataset_config["weight"]
if not isinstance(add_weight, (int, float)) or add_weight <= 0:
raise ValueError(
f"Weight for additional dataset must be positive. Got: {add_weight}"
)
# Add config to the list
add_config = {
"id": additional_dataset_config.get(
"id", additional_dataset_config["hf_name"]
), # Use hf_name if id missing
"hf_name": additional_dataset_config["hf_name"],
"config_name": additional_dataset_config["config_name"],
"split": additional_dataset_config.get(
"split", "train"
), # Default split to 'train'
"weight": add_weight, # Store the relative weight
}
all_configs_to_load.append(add_config)
final_dataset_ids.append(add_config["id"])
# We will normalize weights later
# Adjust weights: Treat original Zyda-2 weights as relative parts summing to 1,
# and the new weight as relative to that sum.
# Example: If new weight is 0.1, total relative weight is 1.0 (Zyda) + 0.1 (New) = 1.1
# New probabilities will be [original_prob / 1.1, ..., new_weight / 1.1]
# Let's use the raw relative weights for calculation before normalization
zyda_relative_sum = 1.0 # The base Zyda probabilities sum to 1, treat this as the relative weight sum
new_relative_weight = add_config["weight"]
total_relative_weight = zyda_relative_sum + new_relative_weight
# Recalculate weights (now storing relative weights before final normalization)
final_weights = [
c["weight"] for c in zyda2_base_configs
] # Start with original probabilities as relative weights
final_weights.append(new_relative_weight) # Add the new relative weight
print(
f"Added dataset '{add_config['id']}' with relative weight {new_relative_weight}."
)
print(f"Total relative weight: {total_relative_weight}")
else:
print("\n--- Loading Standard Zyda-2 Components Only ---")
total_relative_weight = sum(final_weights) # Should be 1.0
# --- Load, Verify, and Prepare All Datasets ---
loaded_datasets = []
print("\n--- Loading and Preparing Datasets ---")
for config in all_configs_to_load:
dataset_id = config["id"]
hf_name = config["hf_name"]
config_name = config["config_name"]
split = config["split"]
print(f"Loading '{dataset_id}' ({hf_name} / {config_name} / {split})...")
try:
# Load dataset
ds = datasets.load_dataset(
hf_name, name=config_name, split=split, streaming=True
)
# Verify and Select Columns
print(
f" Verifying and selecting columns {required_columns} for '{dataset_id}'..."
)
try:
existing_columns = list(ds.features.keys())
except Exception as e:
print(
f" Warning: Could not reliably get features for {dataset_id}. Error: {e}"
)
existing_columns = required_columns # Optimistic assumption
missing_columns = [
col for col in required_columns if col not in existing_columns
]
if missing_columns:
raise ValueError(
f"Dataset '{dataset_id}' ({hf_name}/{config_name}) is missing required columns: {missing_columns}. "
f"Available columns: {existing_columns}"
)
# Select/Remove columns
cols_to_remove = [
col for col in existing_columns if col not in required_columns
]
if cols_to_remove:
print(f" Removing columns: {cols_to_remove}")
processed_ds = ds.remove_columns(cols_to_remove)
else:
processed_ds = ds
loaded_datasets.append(processed_ds)
final_cols = (
list(processed_ds.features.keys())
if processed_ds.features
else required_columns
)
print(f" Successfully prepared '{dataset_id}'. Kept columns: {final_cols}")
except Exception as e:
print(
f"Error loading or processing dataset '{dataset_id}' ({hf_name}/{config_name}): {e}"
)
raise
# --- Normalize Final Weights to Probabilities ---
if total_relative_weight <= 0:
raise ValueError("Total relative weight must be positive.")
probabilities = [w / total_relative_weight for w in final_weights]
if not math.isclose(sum(probabilities), 1.0):
print(
f"Warning: Final probabilities sum to {sum(probabilities):.4f}, not 1.0. Check weights."
)
# --- Interleave Datasets ---
print("\n--- Interleaving Datasets ---")
print(f"Dataset order for interleaving: {final_dataset_ids}")
print(f"Final probabilities: {[f'{p:.4f}' for p in probabilities]}")
interleaved_ds = datasets.interleave_datasets(
loaded_datasets, probabilities=probabilities, stopping_strategy="all_exhausted"
)
print("\nSuccessfully created interleaved streaming dataset.")
return interleaved_ds
# --- Example Usage ---
if __name__ == "__main__":
# --- Example 1: Load only standard Zyda-2 ---
print("\n--- EXAMPLE 1: Loading Standard Zyda-2 ---")
try:
# Specify the columns needed from the standard Zyda-2 mix
std_columns = ["nemo_id", "text"]
zyda2_standard_stream = load_zyda2_with_optional_dataset(
required_columns=std_columns
)
print("\nStandard Zyda-2 stream created. Fetching first item:")
first_item_std = next(iter(zyda2_standard_stream))
print(first_item_std)
except Exception as e:
print(f"\nError in Example 1: {e}")
# --- Example 2: Load Zyda-2 + smollm-corpus-python ---
print("\n\n--- EXAMPLE 2: Loading Zyda-2 + SmolLM Python ---")
try:
# Define the additional dataset configuration
smollm_config = {
"id": "smollm_py",
"hf_name": "BEE-spoke-data/smollm-corpus-python",
"config_name": "default",
"split": "train",
"weight": 0.2, # Give it a relative weight (e.g., 20% compared to Zyda-2's total weight of 1.0)
}
# Since smollm only has 'text' in common with Zyda-2's 'nemo_id' and 'text',
# we MUST specify 'text' as the required column for the combined dataset.
combined_columns = ["text"]
zyda2_plus_smollm_stream = load_zyda2_with_optional_dataset(
additional_dataset_config=smollm_config, required_columns=combined_columns
)
print("\nZyda-2 + SmolLM stream created. Fetching first item:")
first_item_combined = next(iter(zyda2_plus_smollm_stream))
print(first_item_combined)
except Exception as e:
print(f"\nError in Example 2: {e}")