Created
August 16, 2024 03:03
-
-
Save ezyang/f466d9c23452d968c0a3b17eb845d07e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import re
import sys
from pathlib import Path
# Matches an ``if TEMPLATE ...:`` / ``else:`` pair so the whole construct can be
# collapsed to its true branch.  Groups:
#   1 - indentation of the ``if`` header (and of the matching ``else:``)
#   2 - the true-branch lines, still indented 4 columns deeper than group 1
# The false branch is consumed but not captured.  The match ends just before
# any trailing blank lines and before the first non-blank line that is not
# indented deeper than the header, or at end of input.  So it also works when
# the ``else`` body is followed by a dedent or is the last thing in a cell
# (cells often lack a trailing newline).
# ``[ \t]`` is used instead of ``\s`` so indentation can never span line breaks.
# Content lines require a non-space character after the indentation, which keeps
# them disjoint from the blank-line alternative and avoids runaway backtracking.
pattern = re.compile(
    r'^([ \t]*)if TEMPLATE[^:\n]*:[ \t]*\n'
    r'((?:\1[ \t]{4,}\S.*\n|[ \t]*\n)*?)'
    r'\1else:[ \t]*\n'
    r'(?:\1[ \t]{4,}\S.*(?:\n|\Z)|[ \t]*\n)*?'
    r'(?=(?:[ \t]*\n)*(?:[ \t]*\Z|(?!\1[ \t]{4})[ \t]*\S))',
    re.MULTILINE,
)


def replace(match):
    """Return the true branch of a ``pattern`` match, dedented by one level.

    Every true-branch line that starts with the header indentation followed by
    four more whitespace characters loses those four characters.  Blank lines
    are left as they are.

    Args:
        match: a match object produced by ``pattern``.

    Returns:
        The replacement text for the whole ``if TEMPLATE`` / ``else`` construct.
    """
    indent = match.group(1)
    true_branch = match.group(2)
    # Anchor on the exact header indent plus 4 columns.  ``[ \t]`` (not ``\s``)
    # stops a blank line from being merged into the line that follows it.
    one_level = '^' + re.escape(indent) + r'[ \t]{4}'
    # ``indent`` holds only spaces/tabs, so it is safe as a replacement string.
    return re.sub(one_level, indent, true_branch, flags=re.MULTILINE)
# Sample notebook cell for manually checking the rewrite, e.g.:
#     print(pattern.sub(replace, test))
# The ``if TEMPLATE ...:`` / ``else:`` pair should collapse to its dedented
# true branch (here, a bare ``pass``), leaving the rest of the cell unchanged.
test = """
@torch.compile(backend="eager", fullgraph=True)
def cf_check(x):
    u0, u1 = x.tolist()
    if TEMPLATE and False:
        pass
    else:
        torch._check(u0 * 2 == u1 * 3)
    # Do not modify the code below here (imagine it's in framework code you can't edit)
    # NB: In future exercises, we'll use force_guard as a shorthand for this pattern.
    if u0 * 2 == u1 * 3:
        return torch.tensor(True)
    else:
        return torch.tensor(False)
@run_test
def test_check():
    assert cf_check(torch.tensor([12, 8])).item()
"""
def process_notebook(notebook_path):
    """Strip ``if TEMPLATE`` scaffolding from every code cell of a notebook.

    Each code cell's source is rewritten with ``pattern.sub(replace, ...)``.
    The result is written next to the input as ``new_<name>``.  The ``new_``
    prefix goes on the file name, not on any directory part of the path.
    The input file is left untouched.

    Args:
        notebook_path: path (str or ``os.PathLike``) of the ``.ipynb`` file.

    Returns:
        The ``Path`` of the notebook that was written.
    """
    src = Path(notebook_path)
    # Notebooks are JSON text; use UTF-8 explicitly rather than the platform default.
    with src.open('r', encoding='utf-8') as file:
        notebook = json.load(file)
    for cell in notebook['cells']:
        if cell['cell_type'] != 'code':
            continue
        source = cell['source']
        as_list = isinstance(source, list)
        if as_list:
            source = ''.join(source)
        new_source = pattern.sub(replace, source)
        # Cell source may be stored as one string or as a list of lines;
        # write it back in the same form it was read in.
        if as_list:
            new_source = re.findall(r'[^\n]*\n|[^\n]+', new_source)
        cell['source'] = new_source
    out = src.with_name(f"new_{src.name}")
    with out.open('w', encoding='utf-8') as file:
        # ensure_ascii=False keeps non-ASCII characters literal instead of \u-escaped.
        json.dump(notebook, file, indent=1, ensure_ascii=False)
    return out
# Usage: python <this script> [NOTEBOOK.ipynb ...]
# With no arguments, processes 'puzzlers-aug15.ipynb' in the current directory.
# Guarded so that importing this module does not rewrite any files.
if __name__ == "__main__":
    for file_path in sys.argv[1:] or ['puzzlers-aug15.ipynb']:
        process_notebook(file_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment