Created
April 30, 2023 17:23
-
-
Save Dref360/acd4630a58361e28a50d034ad10220b6 to your computer and use it in GitHub Desktop.
Overwrite a Hugging Face Dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os.path | |
import shutil | |
import tempfile | |
from datasets import Dataset, load_from_disk | |
# Directory the demo dataset is saved to and reloaded from.
PATH = '/tmp/b.arrow'
def overwrite_dataset(ds: Dataset, path) -> Dataset:
    """
    Completely overwrite `ds`, allowing you to save it in the same folder.

    The dataset is first written to a fresh staging directory, whatever
    currently exists at `path` is deleted, and the staging directory is then
    moved to `path` and reloaded from there.

    Args:
        ds: The dataset to persist.
        path: Destination directory. Any existing directory there is removed.

    Returns:
        The dataset reloaded from `path`.

    Note:
        The "old" copy will be unusable so be sure that no reference to it exists.
    """
    # Stage next to the destination so the final move stays on one filesystem
    # (a plain rename) instead of copying everything out of the system temp dir.
    parent = os.path.dirname(os.path.abspath(path))
    # Moving from the system temp dir used to create missing parents
    # implicitly; keep accepting a destination whose parent does not exist yet.
    os.makedirs(parent, exist_ok=True)
    staging = tempfile.mkdtemp(dir=parent)

    try:
        ds.save_to_disk(staging)
    except BaseException:
        # Nothing has been replaced yet, so a half-written staging directory
        # is pure garbage: remove it instead of leaking it.
        shutil.rmtree(staging, ignore_errors=True)
        raise

    # From here on the staging directory holds the only copy of the new data,
    # so it is deliberately left in place if the swap below fails.
    if os.path.exists(path):
        shutil.rmtree(path)
    shutil.move(staging, path)
    return load_from_disk(path)
# Create a 100-row dataset, write it to PATH, and reload it from there so
# that `ds` refers to the on-disk copy.
ds = Dataset.from_dict({'a': list(range(100))})
ds.save_to_disk(PATH)
ds = load_from_disk(PATH)

# Derive a modified dataset from the reloaded one.
ds = ds.map(lambda u: {'a': u['a'] + 1})

# Saving straight back to the directory the dataset was loaded from is
# expected to fail with PermissionError (per the original author's note).
try:
    ds.save_to_disk(PATH)
except PermissionError as err:
    print("Caught", err)

# Going through a staging directory is expected to succeed.
ds = overwrite_dataset(ds, PATH)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment