mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
Work on ipadapters and custom adapters
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import gc
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import ForwardRef, List
|
||||
from typing import ForwardRef, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file, load_file
|
||||
@@ -22,6 +22,7 @@ class GenerateConfig:
|
||||
self.sampler = kwargs.get('sampler', 'ddpm')
|
||||
self.width = kwargs.get('width', 512)
|
||||
self.height = kwargs.get('height', 512)
|
||||
self.size_list: Union[List[int], None] = kwargs.get('size_list', None)
|
||||
self.neg = kwargs.get('neg', '')
|
||||
self.seed = kwargs.get('seed', -1)
|
||||
self.guidance_scale = kwargs.get('guidance_scale', 7)
|
||||
@@ -30,6 +31,7 @@ class GenerateConfig:
|
||||
self.neg_2 = kwargs.get('neg_2', None)
|
||||
self.prompts = kwargs.get('prompts', None)
|
||||
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
||||
self.compile = kwargs.get('compile', False)
|
||||
self.ext = kwargs.get('ext', 'png')
|
||||
self.prompt_file = kwargs.get('prompt_file', False)
|
||||
self.prompts_in_file = self.prompts
|
||||
@@ -93,17 +95,26 @@ class GenerateProcess(BaseProcess):
|
||||
self.sd.load_model()
|
||||
|
||||
print("Compiling model...")
|
||||
self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
|
||||
# self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
|
||||
if self.generate_config.compile:
|
||||
self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead")
|
||||
|
||||
print(f"Generating {len(self.generate_config.prompts)} images")
|
||||
# build prompt image configs
|
||||
prompt_image_configs = []
|
||||
for prompt in self.generate_config.prompts:
|
||||
width = self.generate_config.width
|
||||
height = self.generate_config.height
|
||||
|
||||
if self.generate_config.size_list is not None:
|
||||
# randomly select a size
|
||||
width, height = random.choice(self.generate_config.size_list)
|
||||
|
||||
prompt_image_configs.append(GenerateImageConfig(
|
||||
prompt=prompt,
|
||||
prompt_2=self.generate_config.prompt_2,
|
||||
width=self.generate_config.width,
|
||||
height=self.generate_config.height,
|
||||
width=width,
|
||||
height=height,
|
||||
num_inference_steps=self.generate_config.sample_steps,
|
||||
guidance_scale=self.generate_config.guidance_scale,
|
||||
negative_prompt=self.generate_config.neg,
|
||||
|
||||
Reference in New Issue
Block a user