Big refactor of SD runner and added image generator

This commit is contained in:
Jaret Burkett
2023-08-03 14:51:25 -06:00
parent 75ec5d9292
commit 66c6f0f6f7
16 changed files with 923 additions and 430 deletions

View File

@@ -1,4 +1,7 @@
from typing import List
import os
import time
from typing import List, Optional
import random
class SaveConfig:
@@ -27,6 +30,7 @@ class SampleConfig:
self.guidance_scale = kwargs.get('guidance_scale', 7)
self.sample_steps = kwargs.get('sample_steps', 20)
self.network_multiplier = kwargs.get('network_multiplier', 1)
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
class NetworkConfig:
@@ -35,7 +39,7 @@ class NetworkConfig:
rank = kwargs.get('rank', None)
linear = kwargs.get('linear', None)
if rank is not None:
self.rank: int = rank # rank for backward compatibility
self.rank: int = rank # rank for backward compatibility
self.linear: int = rank
elif linear is not None:
self.rank: int = linear
@@ -71,6 +75,7 @@ class ModelConfig:
self.is_v2: bool = kwargs.get('is_v2', False)
self.is_xl: bool = kwargs.get('is_xl', False)
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
self.dtype: str = kwargs.get('dtype', 'float16')
if self.name_or_path is None:
raise ValueError('name_or_path must be specified')
@@ -103,3 +108,197 @@ class SliderConfig:
self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
self.prompt_file: str = kwargs.get('prompt_file', None)
self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
class GenerateImageConfig:
def __init__(
self,
prompt: str = '',
prompt_2: Optional[str] = None,
width: int = 512,
height: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: str = '',
negative_prompt_2: Optional[str] = None,
seed: int = -1,
network_multiplier: float = 1.0,
guidance_rescale: float = 0.0,
# the tag [time] will be replaced with milliseconds since epoch
output_path: str = None, # full image path
output_folder: str = None, # folder to save image in if output_path is not specified
output_ext: str = 'png', # extension to save image as if output_path is not specified
output_tail: str = '', # tail to add to output filename
add_prompt_file: bool = False, # add a prompt file with generated image
):
self.width: int = width
self.height: int = height
self.num_inference_steps: int = num_inference_steps
self.guidance_scale: float = guidance_scale
self.guidance_rescale: float = guidance_rescale
self.prompt: str = prompt
self.prompt_2: str = prompt_2
self.negative_prompt: str = negative_prompt
self.negative_prompt_2: str = negative_prompt_2
self.output_path: str = output_path
self.seed: int = seed
if self.seed == -1:
# generate random one
self.seed = random.randint(0, 2 ** 32 - 1)
self.network_multiplier: float = network_multiplier
self.output_folder: str = output_folder
self.output_ext: str = output_ext
self.add_prompt_file: bool = add_prompt_file
self.output_tail: str = output_tail
self.gen_time: int = int(time.time() * 1000)
# prompt string will override any settings above
self._process_prompt_string()
# handle dual text encoder prompts if nothing passed
if negative_prompt_2 is None:
self.negative_prompt_2 = negative_prompt
if prompt_2 is None:
self.prompt_2 = prompt
# parse prompt paths
if self.output_path is None and self.output_folder is None:
raise ValueError('output_path or output_folder must be specified')
elif self.output_path is not None:
self.output_folder = os.path.dirname(self.output_path)
self.output_ext = os.path.splitext(self.output_path)[1][1:]
self.output_filename_no_ext = os.path.splitext(os.path.basename(self.output_path))[0]
else:
self.output_filename_no_ext = '[time]_[count]'
if len(self.output_tail) > 0:
self.output_filename_no_ext += '_' + self.output_tail
self.output_path = os.path.join(self.output_folder, self.output_filename_no_ext + '.' + self.output_ext)
# adjust height
self.height = max(64, self.height - self.height % 8) # round to divisible by 8
self.width = max(64, self.width - self.width % 8) # round to divisible by 8
def set_gen_time(self, gen_time: int = None):
if gen_time is not None:
self.gen_time = gen_time
else:
self.gen_time = int(time.time() * 1000)
def _get_path_no_ext(self, count: int = 0, max_count=0):
# zero pad count
count_str = str(count).zfill(len(str(max_count)))
# replace [time] with gen time
filename = self.output_filename_no_ext.replace('[time]', str(self.gen_time))
# replace [count] with count
filename = filename.replace('[count]', count_str)
return filename
def get_image_path(self, count: int = 0, max_count=0):
filename = self._get_path_no_ext(count, max_count)
filename += '.' + self.output_ext
# join with folder
return os.path.join(self.output_folder, filename)
def get_prompt_path(self, count: int = 0, max_count=0):
filename = self._get_path_no_ext(count, max_count)
filename += '.txt'
# join with folder
return os.path.join(self.output_folder, filename)
def save_image(self, image, count: int = 0, max_count=0):
# make parent dirs
os.makedirs(self.output_folder, exist_ok=True)
self.set_gen_time()
# TODO save image gen header info for A1111 and us, our seeds probably wont match
image.save(self.get_image_path(count, max_count))
# do prompt file
if self.add_prompt_file:
self.save_prompt_file(count, max_count)
def save_prompt_file(self, count: int = 0, max_count=0):
# save prompt file
with open(self.get_prompt_path(count, max_count), 'w') as f:
prompt = self.prompt
if self.prompt_2 is not None:
prompt += ' --p2 ' + self.prompt_2
if self.negative_prompt is not None:
prompt += ' --n ' + self.negative_prompt
if self.negative_prompt_2 is not None:
prompt += ' --n2 ' + self.negative_prompt_2
prompt += ' --w ' + str(self.width)
prompt += ' --h ' + str(self.height)
prompt += ' --seed ' + str(self.seed)
prompt += ' --cfg ' + str(self.guidance_scale)
prompt += ' --steps ' + str(self.num_inference_steps)
prompt += ' --m ' + str(self.network_multiplier)
prompt += ' --gr ' + str(self.guidance_rescale)
# get gen info
f.write(self.prompt)
def _process_prompt_string(self):
# we will try to support all sd-scripts where we can
# FROM SD-SCRIPTS
# --n Treat everything until the next option as a negative prompt.
# --w Specify the width of the generated image.
# --h Specify the height of the generated image.
# --d Specify the seed for the generated image.
# --l Specify the CFG scale for the generated image.
# --s Specify the number of steps during generation.
# OURS and some QOL additions
# --m Specify the network multiplier for the generated image.
# --p2 Prompt for the second text encoder (SDXL only)
# --n2 Negative prompt for the second text encoder (SDXL only)
# --gr Specify the guidance rescale for the generated image (SDXL only)
# --seed Specify the seed for the generated image same as --d
# --cfg Specify the CFG scale for the generated image same as --l
# --steps Specify the number of steps during generation same as --s
# --network_multiplier Specify the network multiplier for the generated image same as --m
# process prompt string and update values if it has some
if self.prompt is not None and len(self.prompt) > 0:
# process prompt string
prompt = self.prompt
prompt = prompt.strip()
p_split = prompt.split('--')
self.prompt = p_split[0].strip()
if len(p_split) > 1:
for split in p_split[1:]:
# allows multi char flags
flag = split.split(' ')[0].strip()
content = split[len(flag):].strip()
if flag == 'p2':
self.prompt_2 = content
elif flag == 'n':
self.negative_prompt = content
elif flag == 'n2':
self.negative_prompt_2 = content
elif flag == 'w':
self.width = int(content)
elif flag == 'h':
self.height = int(content)
elif flag == 'd':
self.seed = int(content)
elif flag == 'seed':
self.seed = int(content)
elif flag == 'l':
self.guidance_scale = float(content)
elif flag == 'cfg':
self.guidance_scale = float(content)
elif flag == 's':
self.num_inference_steps = int(content)
elif flag == 'steps':
self.num_inference_steps = int(content)
elif flag == 'm':
self.network_multiplier = float(content)
elif flag == 'network_multiplier':
self.network_multiplier = float(content)
elif flag == 'gr':
self.guidance_rescale = float(content)