Work on ipadapters and custom adapters

This commit is contained in:
Jaret Burkett
2024-05-13 06:37:54 -06:00
parent 10e1ecf1e8
commit 5a45c709cd
10 changed files with 150 additions and 67 deletions

View File

@@ -1,7 +1,7 @@
import gc
import os
from collections import OrderedDict
from typing import ForwardRef, List
from typing import ForwardRef, List, Optional, Union
import torch
from safetensors.torch import save_file, load_file
@@ -22,6 +22,7 @@ class GenerateConfig:
self.sampler = kwargs.get('sampler', 'ddpm')
self.width = kwargs.get('width', 512)
self.height = kwargs.get('height', 512)
self.size_list: Union[List[int], None] = kwargs.get('size_list', None)
self.neg = kwargs.get('neg', '')
self.seed = kwargs.get('seed', -1)
self.guidance_scale = kwargs.get('guidance_scale', 7)
@@ -30,6 +31,7 @@ class GenerateConfig:
self.neg_2 = kwargs.get('neg_2', None)
self.prompts = kwargs.get('prompts', None)
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
self.compile = kwargs.get('compile', False)
self.ext = kwargs.get('ext', 'png')
self.prompt_file = kwargs.get('prompt_file', False)
self.prompts_in_file = self.prompts
@@ -93,17 +95,26 @@ class GenerateProcess(BaseProcess):
self.sd.load_model()
print("Compiling model...")
self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
# self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
if self.generate_config.compile:
self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead")
print(f"Generating {len(self.generate_config.prompts)} images")
# build prompt image configs
prompt_image_configs = []
for prompt in self.generate_config.prompts:
width = self.generate_config.width
height = self.generate_config.height
if self.generate_config.size_list is not None:
# randomly select a size
width, height = random.choice(self.generate_config.size_list)
prompt_image_configs.append(GenerateImageConfig(
prompt=prompt,
prompt_2=self.generate_config.prompt_2,
width=self.generate_config.width,
height=self.generate_config.height,
width=width,
height=height,
num_inference_steps=self.generate_config.sample_steps,
guidance_scale=self.generate_config.guidance_scale,
negative_prompt=self.generate_config.neg,