# Quantized Image Generation with FLUX Model
#
# Source: GitHub Gist ovuruska/7b588313bd282ea95b7e86c50446aaf3 by @ovuruska,
# created September 27, 2024 10:58.
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from pathlib import Path
from diffusers import DiffusionPipeline, FluxTransformer2DModel, AutoencoderKL
from transformers import T5EncoderModel, CLIPTextModel
from torchao.quantization import quantize_, int8_weight_only
import time
import csv
from tqdm import tqdm
# Dataset loading: fetch the "pszemraj/text2image-multi-prompt" dataset once at
# import time. PromptDataset reads its splits from this module-level object.
dataset = load_dataset("pszemraj/text2image-multi-prompt")
class PromptDataset(Dataset):
    """Map-style dataset over the ``"text"`` column of one split.

    The split is looked up in the module-level ``dataset`` object at
    construction time.

    Args:
        label: Name of the split to read (for example ``"train"``).
    """

    def __init__(self, label: str) -> None:
        # Copy the column into a plain list so integer indexing and slicing
        # in ``__getitem__`` follow ordinary list semantics.
        self.data = list(dataset[label]["text"])

    def __len__(self) -> int:
        """Return the number of prompts in the split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the prompt at ``idx``; a slice returns a list of prompts."""
        return self.data[idx]
class TrainDataset(PromptDataset):
    """Prompt dataset bound to the ``"train"`` split."""

    def __init__(self) -> None:
        super().__init__("train")
# Shared instance of the "train" prompts; the script entry point iterates over
# its first 150 entries.
train_dataset = TrainDataset()
def _load_quantized(model_cls, ckpt_id: str, subfolder: str):
    """Load one sub-model in bfloat16 and apply int8 weight-only quantization."""
    model = model_cls.from_pretrained(
        ckpt_id, subfolder=subfolder, torch_dtype=torch.bfloat16
    )
    # quantize_ is called for its effect on ``model``; its return value is
    # not used.
    quantize_(model, int8_weight_only())
    return model


def load_pipeline(ckpt_id: str, device: str = "cuda"):
    """Build a diffusion pipeline from individually quantized sub-models.

    The transformer, both text encoders and the VAE are each loaded from
    their subfolder of ``ckpt_id`` in bfloat16, passed through
    ``quantize_(..., int8_weight_only())``, and then handed to
    ``DiffusionPipeline.from_pretrained`` in place of the default components.
    The elapsed load time is printed.

    Args:
        ckpt_id: Checkpoint identifier passed to every ``from_pretrained``
            call.
        device: Device the assembled pipeline is moved to. Defaults to
            ``"cuda"``.

    Returns:
        The assembled pipeline, already moved to ``device``.
    """
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments during the load.
    start = time.perf_counter()

    transformer = _load_quantized(FluxTransformer2DModel, ckpt_id, "transformer")
    text_encoder = _load_quantized(CLIPTextModel, ckpt_id, "text_encoder")
    text_encoder_2 = _load_quantized(T5EncoderModel, ckpt_id, "text_encoder_2")
    vae = _load_quantized(AutoencoderKL, ckpt_id, "vae")

    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        transformer=transformer,
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    ).to(device)

    print(f"Loaded pipeline in {time.perf_counter() - start:.2f}s")
    return pipeline
if __name__ == "__main__":
    # Generate one image per (prompt, seed) pair with the quantized pipeline,
    # saving each image and appending one CSV row of run metadata per image.
    ckpt_id = "black-forest-labs/FLUX.1-schnell"
    schnell_pipeline = load_pipeline(ckpt_id)

    # Output directory name: the checkpoint id with "/", "-", "." and " "
    # each replaced by "_".
    safe_ckpt_name = ckpt_id.translate(str.maketrans("/-. ", "____"))
    quantized_out_dir = Path.home() / "quantized_out" / safe_ckpt_name
    quantized_out_dir.mkdir(exist_ok=True, parents=True)

    seeds = [42, 1337, 7, 13, 666, 420, 69, 314, 271, 888]
    csv_file = Path.home() / "quantized_image_generation_log.csv"
    num_inference_steps = 8
    height = 1024
    width = 1024
    # NOTE(review): the FLUX.1-schnell model card example passes
    # guidance_scale=0.0 — confirm that 7.5 is intended for this checkpoint.
    guidance_scale = 7.5
    csv_headers = [
        "model_name",
        "prompt",
        "num_inference_steps",
        "seed",
        "height",
        "width",
        "guidance_scale",
        "total_memory_used",
        "total_duration",
        "image_id",
        "image_path",
    ]

    # Append mode keeps rows from earlier runs. An explicit UTF-8 encoding
    # keeps non-ASCII prompts from depending on the platform default.
    with open(csv_file, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.DictWriter(file, fieldnames=csv_headers)
        # Only write the header when the log file is empty.
        if file.tell() == 0:
            writer.writeheader()

        for prompt in tqdm(train_dataset[:150], desc="Generating images"):
            for seed in seeds:
                start_time = time.perf_counter()
                # Reset so max_memory_allocated() below reflects this
                # generation only.
                torch.cuda.reset_peak_memory_stats()
                image = schnell_pipeline(
                    prompt=prompt,
                    num_inference_steps=num_inference_steps,
                    height=height,
                    width=width,
                    guidance_scale=guidance_scale,
                    generator=torch.Generator("cuda").manual_seed(seed),
                ).images[0]
                # Peak allocated CUDA memory, converted from bytes to GiB.
                total_memory_used = torch.cuda.max_memory_allocated() / (1024 ** 3)
                total_duration = time.perf_counter() - start_time

                # The file name is the current epoch time with the decimal
                # point removed.
                timestamp = str(time.time())
                image_filename = f"{timestamp.replace('.', '')}.png"
                image_path = quantized_out_dir / image_filename
                image.save(image_path)

                writer.writerow({
                    "model_name": ckpt_id,
                    "prompt": prompt,
                    "num_inference_steps": num_inference_steps,
                    "seed": seed,
                    "height": height,
                    "width": width,
                    "guidance_scale": guidance_scale,
                    "total_memory_used": f"{total_memory_used:.2f}",
                    "total_duration": f"{total_duration:.2f}",
                    "image_id": image_filename,
                    "image_path": str(image_path),
                })
                # Flush after every row so an interrupted run does not lose
                # log entries for images that were already saved.
                file.flush()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment