-
-
Save wkpark/f70bc55f23c479e302dc4a5ebd5ae1c2 to your computer and use it in GitHub Desktop.
Replace the VAE in a Stable Diffusion model with a new VAE. Tested on SD v1.4 and v1.5 models.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#
# Script by https://github.com/ProGamerGov
#
# ChangeLog:
# - 2023/06/22 (wkpark): support safetensors, save as float16 when the model is float16, resolve model/VAE file names
#
import copy
import os
import sys
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file
def load_model(path):
    """Load a checkpoint file onto the CPU and return its state dict.

    Args:
        path: ``pathlib.Path`` or plain string path. A ``.safetensors`` suffix
            selects the safetensors loader; any other suffix goes through
            ``torch.load``.

    Returns:
        The checkpoint's ``"state_dict"`` entry when it has one, otherwise the
        loaded object itself (already a bare state dict).
    """
    # Normalize so callers may pass str / os.PathLike as well as Path; the
    # original used path.suffix directly and crashed on plain strings.
    path = Path(path)
    if path.suffix == ".safetensors":
        return load_file(path, device="cpu")
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code. Only load .ckpt files from trusted sources; prefer .safetensors.
    ckpt = torch.load(path, map_location="cpu")
    return ckpt["state_dict"] if "state_dict" in ckpt else ckpt
# Path to model and VAE files that you want to merge
if len(sys.argv) < 2:
    print("Usage: replace_vae.py model_file vae_file", file=sys.stderr)
    sys.exit(1)

model_file_path = Path(sys.argv[1])

default_vae = "vae-ft-mse-840000-ema-pruned"
if len(sys.argv) > 2:
    # An explicitly requested VAE must exist; never silently fall back to the
    # default or fail later inside the loader.
    vae_file_path = Path(sys.argv[2])
    if not vae_file_path.exists():
        print(f"no vae file {str(vae_file_path)} found!", file=sys.stderr)
        sys.exit(1)
else:
    vae_file_path = Path(default_vae + ".safetensors")
    if not vae_file_path.exists():
        # Search the current directory, then ../VAE, preferring .safetensors
        # over .ckpt within each directory.
        candidates = (
            Path(search_dir) / f"{default_vae}.{ext}"
            for search_dir in (".", "../VAE")
            for ext in ("safetensors", "ckpt")
        )
        found = next((candidate for candidate in candidates if candidate.exists()), None)
        if found is None:
            print(f"no default vae file {default_vae} found!", file=sys.stderr)
            sys.exit(1)
        vae_file_path = found
        print(f"- vae file {str(vae_file_path)} found!")
# Resolve the model path. If the given path does not exist and does not
# already end in a model extension, try "<name>.safetensors" then
# "<name>.ckpt". Checking for a known extension (rather than an empty suffix)
# also covers dotted names such as "anything-v4.5", whose suffix is ".5".
# model_file is bound unconditionally so the error message below can no
# longer raise NameError for a missing path that has an extension.
model_file = sys.argv[1]
if not model_file_path.exists() and model_file_path.suffix not in (".safetensors", ".ckpt"):
    for ext in ("safetensors", "ckpt"):
        candidate = Path(model_file + "." + ext)
        if candidate.exists():
            model_file_path = candidate
            break
if not model_file_path.exists():
    print(f"no model file {model_file} found!", file=sys.stderr)
    sys.exit(1)
print(f"- vae file = {str(vae_file_path)}")
print(f"- model file = {str(model_file_path)}")
# The merged checkpoint goes next to the input model as "<stem>-vae<suffix>",
# keeping the input's suffix.
new_model_path = model_file_path.parent / f"{model_file_path.stem}-vae{model_file_path.suffix}"
# Read both state dicts onto the CPU: the replacement VAE first, then the
# full model whose first_stage_model.* entries will be overwritten.
vae_model = load_model(vae_file_path)
full_model = load_model(model_file_path)
# check original dtype
# The SD v1.x text-encoder position embedding is the primary probe for the
# checkpoint's precision. Layouts without that key (e.g. SD v2.x, whose text
# encoder lives under cond_stage_model.model.*) used to raise KeyError; fall
# back to the UNet weights, then to any floating-point tensor.
def _is_float_tensor(value):
    """Return True if *value* is a floating-point torch tensor."""
    return isinstance(value, torch.Tensor) and value.is_floating_point()


probe_key = "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight"
probe = full_model.get(probe_key)
if not _is_float_tensor(probe):
    probe = next(
        (v for k, v in full_model.items()
         if k.startswith("model.diffusion_model.") and _is_float_tensor(v)),
        None,
    )
if probe is None:
    probe = next((v for v in full_model.values() if _is_float_tensor(v)), None)
model_dtype = probe.dtype if probe is not None else torch.float32
half = model_dtype != torch.float32  # same meaning as the original flag

# Replace VAE in model file with new VAE
vae_prefix = "first_stage_model."
for k, value in vae_model.items():
    # Accept VAE files whose keys already carry the full-model prefix instead
    # of producing "first_stage_model.first_stage_model.*".
    sub_key = k[len(vae_prefix):] if k.startswith(vae_prefix) else k
    # Skip keys starting with "loss" or "mode", as before (presumably
    # training-only weights such as the loss network or model_ema).
    if sub_key.startswith(("loss", "mode")):
        continue
    key_name = vae_prefix + sub_key
    # Copy so the merged checkpoint never shares storage with the VAE dict.
    new_value = copy.deepcopy(value)
    if _is_float_tensor(new_value):
        # Match the tensor being replaced when there is one; otherwise use the
        # checkpoint's overall dtype. Unlike the old unconditional .half(), this
        # keeps bfloat16 models bfloat16 and upcasts fp16 VAEs into fp32 models.
        existing = full_model.get(key_name)
        target_dtype = existing.dtype if _is_float_tensor(existing) else model_dtype
        new_value = new_value.to(target_dtype)
    full_model[key_name] = new_value
# Save model with new VAE. A .safetensors target gets a flat safetensors file;
# any other suffix gets a torch checkpoint wrapping the tensors in "state_dict".
target = str(new_model_path)
if new_model_path.suffix != ".safetensors":
    torch.save({"state_dict": full_model}, target)
else:
    save_file(full_model, target)
print(f"new file {str(new_model_path)} saved!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment