Created
April 30, 2024 02:28
-
-
Save clane9/6f12d2372ba00fb01adda1074e8c5a45 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import io\n", | |
"import psutil\n", | |
"from pathlib import Path\n", | |
"\n", | |
"import numpy as np\n", | |
"import webdataset as wds\n", | |
"from torch.utils.data import DataLoader" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def get_memory():
    """Return RSS memory usage in MiB for this process and all of its children."""
    main_proc = psutil.Process()
    all_procs = [main_proc] + main_proc.children(recursive=True)
    return [p.memory_info().rss / 1024**2 for p in all_procs]
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Generate a dummy dataset consisting of sequences of high-dimensional data." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Change directory as needed. parents=True also creates missing intermediate
# directories; the original mkdir(exist_ok=True) raised FileNotFoundError
# unless the parent already existed.
root = Path("/local/slurm-23665773/local/data")
root.mkdir(parents=True, exist_ok=True)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def encode_numpy(data: np.ndarray) -> bytes:
    """Serialize an array to bytes in .npy format (header + raw data)."""
    buffer = io.BytesIO()
    np.save(buffer, data)
    return buffer.getvalue()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"# writing /local/slurm-23665773/local/data/000000.tar 0 0.0 GB 0\n", | |
"# writing /local/slurm-23665773/local/data/000001.tar 50 0.4 GB 50\n", | |
"# writing /local/slurm-23665773/local/data/000002.tar 50 0.4 GB 100\n", | |
"# writing /local/slurm-23665773/local/data/000003.tar 50 0.4 GB 150\n", | |
"# writing /local/slurm-23665773/local/data/000004.tar 50 0.4 GB 200\n", | |
"# writing /local/slurm-23665773/local/data/000005.tar 50 0.4 GB 250\n", | |
"# writing /local/slurm-23665773/local/data/000006.tar 50 0.4 GB 300\n", | |
"# writing /local/slurm-23665773/local/data/000007.tar 50 0.4 GB 350\n", | |
"# writing /local/slurm-23665773/local/data/000008.tar 50 0.4 GB 400\n", | |
"# writing /local/slurm-23665773/local/data/000009.tar 50 0.4 GB 450\n", | |
"# writing /local/slurm-23665773/local/data/000010.tar 50 0.4 GB 500\n", | |
"# writing /local/slurm-23665773/local/data/000011.tar 50 0.4 GB 550\n", | |
"# writing /local/slurm-23665773/local/data/000012.tar 50 0.4 GB 600\n", | |
"# writing /local/slurm-23665773/local/data/000013.tar 50 0.4 GB 650\n", | |
"# writing /local/slurm-23665773/local/data/000014.tar 50 0.4 GB 700\n", | |
"# writing /local/slurm-23665773/local/data/000015.tar 50 0.4 GB 750\n", | |
"# writing /local/slurm-23665773/local/data/000016.tar 50 0.4 GB 800\n", | |
"# writing /local/slurm-23665773/local/data/000017.tar 50 0.4 GB 850\n", | |
"# writing /local/slurm-23665773/local/data/000018.tar 50 0.4 GB 900\n", | |
"# writing /local/slurm-23665773/local/data/000019.tar 50 0.4 GB 950\n" | |
] | |
} | |
], | |
"source": [ | |
# Write 1000 random uint8 sequences of shape (256, 32768) into ~400 MB tar
# shards. encoder=False because we hand ShardWriter pre-encoded .npy bytes.
with wds.ShardWriter(str(root / "%06d.tar"), maxsize=400 * 1024 * 1024, encoder=False) as sink:
    for idx in range(1000):
        seq = np.random.randint(0, 255, (256, 32768), dtype=np.uint8)
        sink.write({"__key__": f"{idx:05d}", "npy": encode_numpy(seq)})
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Test 1: iterating over full sequences, shuffled and batched.\n", | |
"\n", | |
"Here the memory usage is roughly the buffer size plus one shard." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Pipeline: read the 20 shards, decode the .npy payloads, emit (array,)
# tuples, and shuffle with a 200-sample buffer.
dataset = wds.WebDataset(str(root / "{000000..000019}.tar"))
dataset = dataset.decode().to_tuple("npy").shuffle(200)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(256, 32768)\n" | |
] | |
} | |
], | |
"source": [ | |
# Pull one sample to sanity-check the pipeline output shape.
(x,) = next(iter(dataset))
print(x.shape)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1600.0\n" | |
] | |
} | |
], | |
"source": [ | |
# Expected shuffle-buffer footprint: 200 samples of 256 x 32768 uint8, in MiB.
buffer_size = (200 * 256 * 32768) / 2 ** 20
print(buffer_size)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Batching happens inside the dataset pipeline, hence batch_size=None.
loader = DataLoader(dataset.batched(8), batch_size=None, num_workers=2)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[ 0] Shape: torch.Size([8, 256, 32768]) Mem: (970,1301,1373)\n", | |
"[ 10] Shape: torch.Size([8, 256, 32768]) Mem: (970,1565,1691)\n", | |
"[ 20] Shape: torch.Size([8, 256, 32768]) Mem: (970,1893,1976)\n", | |
"[ 30] Shape: torch.Size([8, 256, 32768]) Mem: (970,2053,2062)\n", | |
"[ 40] Shape: torch.Size([8, 256, 32768]) Mem: (970,2078,2055)\n", | |
"[ 50] Shape: torch.Size([8, 256, 32768]) Mem: (970,2094,2051)\n", | |
"[ 60] Shape: torch.Size([8, 256, 32768]) Mem: (970,2097,2055)\n", | |
"[ 70] Shape: torch.Size([8, 256, 32768]) Mem: (970,2113,2065)\n", | |
"[ 80] Shape: torch.Size([8, 256, 32768]) Mem: (970,2093,2077)\n", | |
"[ 90] Shape: torch.Size([8, 256, 32768]) Mem: (970,2079,2077)\n", | |
"[ 100] Shape: torch.Size([8, 256, 32768]) Mem: (970,2082,2097)\n", | |
"[ 110] Shape: torch.Size([8, 256, 32768]) Mem: (970,2002,2020)\n", | |
"[ 120] Shape: torch.Size([8, 256, 32768]) Mem: (970,1999,2017)\n" | |
] | |
} | |
], | |
"source": [ | |
# Drain the loader, reporting per-process memory every 10 batches.
for step, (batch,) in enumerate(loader):
    if step % 10 != 0:
        continue
    usage = ",".join(f"{m:.0f}" for m in get_memory())
    print(f"[{step:>6d}] Shape: {batch.shape} Mem: ({usage})")
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Test 2: iterating over short clips sampled sequentially from each sequence, then shuffled. The shuffle buffer now holds more samples, but the same total number of bytes.\n", | |
"\n", | |
"Now, the memory usage of the data loader is much higher! Why??" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def to_clips(window: int = 16):
    """Build a webdataset filter that splits each sequence into consecutive
    non-overlapping clips of length ``window`` along the first axis.

    Parameters
    ----------
    window : int
        Clip length. A trailing remainder shorter than ``window`` is dropped.

    Returns
    -------
    callable
        Generator transform mapping an iterable of ``(x,)`` tuples to
        ``(clip,)`` tuples, preserving sequence order.
    """
    def _filter(source):
        for (x,) in source:
            # Stop at len(x) - window + 1 so the final full window is
            # included; the original stop of len(x) - window dropped the
            # last clip (e.g. 15 instead of 16 clips for len 256, window 16).
            for start in range(0, len(x) - window + 1, window):
                yield (x[start:start + window],)

    return _filter
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Same shard pipeline as before, but each decoded sequence is chopped into
# 16-row clips before shuffling; 3200 clips equal 200 full sequences in bytes.
dataset2 = wds.WebDataset(str(root / "{000000..000019}.tar"))
dataset2 = dataset2.decode().to_tuple("npy").compose(to_clips(window=16)).shuffle(3200)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(16, 32768)\n" | |
] | |
} | |
], | |
"source": [ | |
# Pull one clip to confirm the per-clip shape.
(x,) = next(iter(dataset2))
print(x.shape)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1600.0\n" | |
] | |
} | |
], | |
"source": [ | |
# Expected shuffle-buffer footprint: 3200 clips of 16 x 32768 uint8, in MiB.
buffer_size = (3200 * 16 * 32768) / 2 ** 20
print(buffer_size)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# 128 clips x 16 rows per batch matches the 8 x 256 rows of the first loader.
loader2 = DataLoader(dataset2.batched(128), batch_size=None, num_workers=2)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[ 0] Shape: torch.Size([128, 16, 32768]) Mem: (970,933,893)\n", | |
"[ 10] Shape: torch.Size([128, 16, 32768]) Mem: (970,1261,1279)\n", | |
"[ 20] Shape: torch.Size([128, 16, 32768]) Mem: (970,1901,1917)\n", | |
"[ 30] Shape: torch.Size([128, 16, 32768]) Mem: (970,2541,2566)\n", | |
"[ 40] Shape: torch.Size([128, 16, 32768]) Mem: (970,3174,3247)\n", | |
"[ 50] Shape: torch.Size([128, 16, 32768]) Mem: (933,3660,3819)\n", | |
"[ 60] Shape: torch.Size([128, 16, 32768]) Mem: (933,3962,4089)\n", | |
"[ 70] Shape: torch.Size([128, 16, 32768]) Mem: (933,4237,4156)\n", | |
"[ 80] Shape: torch.Size([128, 16, 32768]) Mem: (933,4130,4208)\n", | |
"[ 90] Shape: torch.Size([128, 16, 32768]) Mem: (933,4138,4190)\n", | |
"[ 100] Shape: torch.Size([128, 16, 32768]) Mem: (933,4142,4188)\n", | |
"[ 110] Shape: torch.Size([128, 16, 32768]) Mem: (933,4098,4124)\n" | |
] | |
} | |
], | |
"source": [ | |
# Drain the clip loader, reporting per-process memory every 10 batches.
for step, (batch,) in enumerate(loader2):
    if step % 10 != 0:
        continue
    usage = ",".join(f"{m:.0f}" for m in get_memory())
    print(f"[{step:>6d}] Shape: {batch.shape} Mem: ({usage})")
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": ".venv", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.14" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment