Created
September 27, 2024 10:58
-
-
Save ovuruska/7b588313bd282ea95b7e86c50446aaf3 to your computer and use it in GitHub Desktop.
Quantized Image Generation with FLUX Model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv
import time
from pathlib import Path

import torch
from datasets import load_dataset
from diffusers import AutoencoderKL, DiffusionPipeline, FluxTransformer2DModel
from torch.utils.data import Dataset
from torchao.quantization import int8_weight_only, quantize_
from tqdm import tqdm
from transformers import CLIPTextModel, T5EncoderModel
# Hugging Face prompt dataset, loaded once when this module is imported.
# PromptDataset indexes it by split name and then by the "text" column.
_PROMPT_DATASET_ID = "pszemraj/text2image-multi-prompt"
dataset = load_dataset(_PROMPT_DATASET_ID)
class PromptDataset(Dataset):
    """Map-style torch Dataset exposing one text column of one dataset split.

    Indexing, including slicing, is delegated to the underlying column object.
    ``ds[i]`` returns a single entry. ``ds[a:b]`` returns whatever the column's
    own slicing returns, which is a list when the column is a list.
    """

    def __init__(self, label: str, *, source=None, column: str = "text"):
        """Select ``source[label][column]`` as the backing sequence.

        Args:
            label: Split name to read, e.g. ``"train"``.
            source: Anything supporting ``source[label][column]`` (e.g. a
                ``DatasetDict`` or a plain nested dict). Defaults to the
                module-level ``dataset``, so existing callers are unaffected.
            column: Name of the column to expose. Defaults to ``"text"``.
        """
        if source is None:
            # Backward-compatible fallback: the global is looked up at call
            # time rather than being a hidden hard-wired dependency.
            source = dataset
        self.data = source[label][column]

    def __len__(self) -> int:
        """Return the number of entries in the selected column."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``self.data[idx]`` (an int index or a slice)."""
        return self.data[idx]
class TrainDataset(PromptDataset):
    """PromptDataset pinned to the ``"train"`` split."""

    def __init__(self):
        # Only the split name is fixed here; storage and indexing are inherited.
        super().__init__(label="train")
# Built at import time: wraps the "train" split of the module-level dataset.
train_dataset: TrainDataset = TrainDataset()
def _load_quantized_component(model_cls, ckpt_id: str, subfolder: str, dtype: torch.dtype):
    """Load ``model_cls`` from ``ckpt_id``/``subfolder`` in ``dtype`` and int8 weight-only quantize it.

    ``quantize_`` modifies the model in place; the same (now quantized) object is returned.
    """
    model = model_cls.from_pretrained(ckpt_id, subfolder=subfolder, torch_dtype=dtype)
    quantize_(model, int8_weight_only())
    return model


def load_pipeline(ckpt_id: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
    """Build a DiffusionPipeline for ``ckpt_id`` whose sub-models are int8 weight-only quantized.

    Four components are each loaded from their checkpoint subfolder in ``dtype``
    and quantized with torchao's ``int8_weight_only``:

    - the transformer (``FluxTransformer2DModel``)
    - the CLIP text encoder
    - the T5 text encoder
    - the VAE

    They are then passed to ``DiffusionPipeline.from_pretrained`` as pre-built components.

    Args:
        ckpt_id: Hub repo id or local path of the checkpoint.
        device: Device the assembled pipeline is moved to. Defaults to ``"cuda"``.
        dtype: Load dtype for every component and for the pipeline.
            Defaults to ``torch.bfloat16``.

    Returns:
        The pipeline, already moved to ``device``. The load time is printed.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed by clock changes.
    start = time.perf_counter()
    # Keys double as the checkpoint subfolder name and the pipeline kwarg name;
    # order matches the original load order.
    component_classes = (
        ("transformer", FluxTransformer2DModel),
        ("text_encoder", CLIPTextModel),
        ("text_encoder_2", T5EncoderModel),
        ("vae", AutoencoderKL),
    )
    components = {
        name: _load_quantized_component(cls, ckpt_id, name, dtype)
        for name, cls in component_classes
    }
    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        torch_dtype=dtype,
        **components,
    ).to(device)
    print(f"Loaded pipeline in {time.perf_counter() - start:.2f}s")
    return pipeline
def _generate_image(pipeline, prompt, seed, *, num_inference_steps, height, width, guidance_scale):
    """Run one seeded generation and measure it.

    Returns:
        ``(image, peak_gib, seconds)``:

        - ``image``: the first entry of the pipeline output's ``.images``.
        - ``peak_gib``: ``torch.cuda.max_memory_allocated()`` in GiB, measured
          after the peak counter was reset for this call.
        - ``seconds``: the elapsed time of the call.
    """
    # Monotonic clock: suitable for measuring durations.
    start_time = time.perf_counter()
    torch.cuda.reset_peak_memory_stats()
    image = pipeline(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]
    peak_gib = torch.cuda.max_memory_allocated() / (1024 ** 3)
    return image, peak_gib, time.perf_counter() - start_time


def main() -> None:
    """Generate one image per (prompt, seed) and append one CSV log row per image.

    Uses the first 150 prompts of ``train_dataset`` and 10 fixed seeds.
    Images are saved under ``~/quantized_out/<sanitized ckpt_id>/``. Rows are
    appended to ``~/quantized_image_generation_log.csv``; the header is written
    only when that file is empty.
    """
    ckpt_id = "black-forest-labs/FLUX.1-schnell"
    schnell_pipeline = load_pipeline(ckpt_id)

    # Replace '/', '-', '.' and ' ' with '_' in one pass for a filesystem-friendly directory name.
    safe_ckpt_name = ckpt_id.translate(str.maketrans("/-. ", "____"))
    quantized_out_dir = Path.home() / "quantized_out" / safe_ckpt_name
    quantized_out_dir.mkdir(exist_ok=True, parents=True)

    seeds = [42, 1337, 7, 13, 666, 420, 69, 314, 271, 888]
    num_prompts = 150
    csv_file = Path.home() / "quantized_image_generation_log.csv"
    num_inference_steps = 8
    height = 1024
    width = 1024
    # NOTE(review): diffusers documents FLUX.1-schnell as guidance-distilled and
    # uses it with guidance_scale=0.0, so this value likely has no effect on the
    # images. It is kept unchanged so logged rows stay comparable; confirm.
    guidance_scale = 7.5
    csv_headers = [
        "model_name", "prompt", "num_inference_steps", "seed", "height", "width",
        "guidance_scale", "total_memory_used", "total_duration", "image_id", "image_path",
    ]

    # Prompts are free text, so pin the encoding rather than relying on the
    # platform default (which may not be able to encode every character).
    with open(csv_file, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.DictWriter(file, fieldnames=csv_headers)
        # Append mode positions the stream at end-of-file, so 0 means the file is empty.
        if file.tell() == 0:
            writer.writeheader()
        for prompt in tqdm(train_dataset[:num_prompts], desc="Generating images"):
            for seed in seeds:
                image, total_memory_used, total_duration = _generate_image(
                    schnell_pipeline,
                    prompt,
                    seed,
                    num_inference_steps=num_inference_steps,
                    height=height,
                    width=width,
                    guidance_scale=guidance_scale,
                )
                # Filename derived from the wall-clock timestamp with the dot removed.
                timestamp = str(time.time())
                image_filename = f"{timestamp.replace('.', '')}.png"
                image_path = quantized_out_dir / image_filename
                image.save(image_path)
                writer.writerow({
                    "model_name": ckpt_id,
                    "prompt": prompt,
                    "num_inference_steps": num_inference_steps,
                    "seed": seed,
                    "height": height,
                    "width": width,
                    "guidance_scale": guidance_scale,
                    "total_memory_used": f"{total_memory_used:.2f}",
                    "total_duration": f"{total_duration:.2f}",
                    "image_id": image_filename,
                    "image_path": str(image_path),
                })
                # Flush per row so finished generations are not lost if the run crashes mid-way.
                file.flush()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment