mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
New image generation img2img. various tweaks and fixes
This commit is contained in:
188
extensions_built_in/advanced_generator/Img2ImgGenerator.py
Normal file
188
extensions_built_in/advanced_generator/Img2ImgGenerator.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from diffusers import T2IAdapter
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from diffusers import StableDiffusionXLImg2ImgPipeline
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
|
||||||
|
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||||
|
from toolkit.sampler import get_sampler
|
||||||
|
from toolkit.stable_diffusion_model import StableDiffusion
|
||||||
|
import gc
|
||||||
|
import torch
|
||||||
|
from jobs.process import BaseExtensionProcess
|
||||||
|
from toolkit.data_loader import get_dataloader_from_datasets
|
||||||
|
from toolkit.train_tools import get_torch_dtype
|
||||||
|
from controlnet_aux.midas import MidasDetector
|
||||||
|
from diffusers.utils import load_image
|
||||||
|
|
||||||
|
|
||||||
|
def flush():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateConfig:
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.prompts: List[str]
|
||||||
|
self.sampler = kwargs.get('sampler', 'ddpm')
|
||||||
|
self.neg = kwargs.get('neg', '')
|
||||||
|
self.seed = kwargs.get('seed', -1)
|
||||||
|
self.walk_seed = kwargs.get('walk_seed', False)
|
||||||
|
self.guidance_scale = kwargs.get('guidance_scale', 7)
|
||||||
|
self.sample_steps = kwargs.get('sample_steps', 20)
|
||||||
|
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
||||||
|
self.ext = kwargs.get('ext', 'png')
|
||||||
|
self.denoise_strength = kwargs.get('denoise_strength', 0.5)
|
||||||
|
self.trigger_word = kwargs.get('trigger_word', None)
|
||||||
|
|
||||||
|
|
||||||
|
class Img2ImgGenerator(BaseExtensionProcess):
|
||||||
|
|
||||||
|
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||||
|
super().__init__(process_id, job, config)
|
||||||
|
self.output_folder = self.get_conf('output_folder', required=True)
|
||||||
|
self.copy_inputs_to = self.get_conf('copy_inputs_to', None)
|
||||||
|
self.device = self.get_conf('device', 'cuda')
|
||||||
|
self.model_config = ModelConfig(**self.get_conf('model', required=True))
|
||||||
|
self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
|
||||||
|
self.is_latents_cached = True
|
||||||
|
raw_datasets = self.get_conf('datasets', None)
|
||||||
|
if raw_datasets is not None and len(raw_datasets) > 0:
|
||||||
|
raw_datasets = preprocess_dataset_raw_config(raw_datasets)
|
||||||
|
self.datasets = None
|
||||||
|
self.datasets_reg = None
|
||||||
|
self.dtype = self.get_conf('dtype', 'float16')
|
||||||
|
self.torch_dtype = get_torch_dtype(self.dtype)
|
||||||
|
self.params = []
|
||||||
|
if raw_datasets is not None and len(raw_datasets) > 0:
|
||||||
|
for raw_dataset in raw_datasets:
|
||||||
|
dataset = DatasetConfig(**raw_dataset)
|
||||||
|
is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
|
||||||
|
if not is_caching:
|
||||||
|
self.is_latents_cached = False
|
||||||
|
if dataset.is_reg:
|
||||||
|
if self.datasets_reg is None:
|
||||||
|
self.datasets_reg = []
|
||||||
|
self.datasets_reg.append(dataset)
|
||||||
|
else:
|
||||||
|
if self.datasets is None:
|
||||||
|
self.datasets = []
|
||||||
|
self.datasets.append(dataset)
|
||||||
|
|
||||||
|
self.progress_bar = None
|
||||||
|
self.sd = StableDiffusion(
|
||||||
|
device=self.device,
|
||||||
|
model_config=self.model_config,
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
print(f"Using device {self.device}")
|
||||||
|
self.data_loader: DataLoader = None
|
||||||
|
self.adapter: T2IAdapter = None
|
||||||
|
|
||||||
|
def to_pil(self, img):
|
||||||
|
# image comes in -1 to 1. convert to a PIL RGB image
|
||||||
|
img = (img + 1) / 2
|
||||||
|
img = img.clamp(0, 1)
|
||||||
|
img = img[0].permute(1, 2, 0).cpu().numpy()
|
||||||
|
img = (img * 255).astype(np.uint8)
|
||||||
|
image = Image.fromarray(img)
|
||||||
|
return image
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
with torch.no_grad():
|
||||||
|
super().run()
|
||||||
|
print("Loading model...")
|
||||||
|
self.sd.load_model()
|
||||||
|
device = torch.device(self.device)
|
||||||
|
|
||||||
|
|
||||||
|
if self.model_config.is_xl:
|
||||||
|
pipe = StableDiffusionXLImg2ImgPipeline(
|
||||||
|
vae=self.sd.vae,
|
||||||
|
unet=self.sd.unet,
|
||||||
|
text_encoder=self.sd.text_encoder[0],
|
||||||
|
text_encoder_2=self.sd.text_encoder[1],
|
||||||
|
tokenizer=self.sd.tokenizer[0],
|
||||||
|
tokenizer_2=self.sd.tokenizer[1],
|
||||||
|
scheduler=get_sampler(self.generate_config.sampler),
|
||||||
|
).to(device, dtype=self.torch_dtype)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only XL models are supported")
|
||||||
|
pipe.set_progress_bar_config(disable=True)
|
||||||
|
|
||||||
|
# pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||||
|
# midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
|
||||||
|
|
||||||
|
self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
|
||||||
|
|
||||||
|
num_batches = len(self.data_loader)
|
||||||
|
pbar = tqdm(total=num_batches, desc="Generating images")
|
||||||
|
seed = self.generate_config.seed
|
||||||
|
# load images from datasets, use tqdm
|
||||||
|
for i, batch in enumerate(self.data_loader):
|
||||||
|
batch: DataLoaderBatchDTO = batch
|
||||||
|
|
||||||
|
file_item: FileItemDTO = batch.file_items[0]
|
||||||
|
img_path = file_item.path
|
||||||
|
img_filename = os.path.basename(img_path)
|
||||||
|
img_filename_no_ext = os.path.splitext(img_filename)[0]
|
||||||
|
img_filename = img_filename_no_ext + '.' + self.generate_config.ext
|
||||||
|
output_path = os.path.join(self.output_folder, img_filename)
|
||||||
|
output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
|
||||||
|
|
||||||
|
if self.copy_inputs_to is not None:
|
||||||
|
output_inputs_path = os.path.join(self.copy_inputs_to, img_filename)
|
||||||
|
output_inputs_caption_path = os.path.join(self.copy_inputs_to, img_filename_no_ext + '.txt')
|
||||||
|
else:
|
||||||
|
output_inputs_path = None
|
||||||
|
output_inputs_caption_path = None
|
||||||
|
|
||||||
|
caption = batch.get_caption_list()[0]
|
||||||
|
if self.generate_config.trigger_word is not None:
|
||||||
|
caption = caption.replace('[trigger]', self.generate_config.trigger_word)
|
||||||
|
|
||||||
|
img: torch.Tensor = batch.tensor.clone()
|
||||||
|
image = self.to_pil(img)
|
||||||
|
|
||||||
|
|
||||||
|
# image.save(output_depth_path)
|
||||||
|
pipe: StableDiffusionXLImg2ImgPipeline = pipe
|
||||||
|
|
||||||
|
gen_images = pipe.__call__(
|
||||||
|
prompt=caption,
|
||||||
|
negative_prompt=self.generate_config.neg,
|
||||||
|
image=image,
|
||||||
|
num_inference_steps=self.generate_config.sample_steps,
|
||||||
|
guidance_scale=self.generate_config.guidance_scale,
|
||||||
|
strength=self.generate_config.denoise_strength,
|
||||||
|
).images[0]
|
||||||
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||||
|
gen_images.save(output_path)
|
||||||
|
|
||||||
|
# save caption
|
||||||
|
with open(output_caption_path, 'w') as f:
|
||||||
|
f.write(caption)
|
||||||
|
|
||||||
|
if output_inputs_path is not None:
|
||||||
|
os.makedirs(os.path.dirname(output_inputs_path), exist_ok=True)
|
||||||
|
image.save(output_inputs_path)
|
||||||
|
with open(output_inputs_caption_path, 'w') as f:
|
||||||
|
f.write(caption)
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
batch.cleanup()
|
||||||
|
|
||||||
|
pbar.close()
|
||||||
|
print("Done generating images")
|
||||||
|
# cleanup
|
||||||
|
del self.sd
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
@@ -36,7 +36,24 @@ class PureLoraGenerator(Extension):
|
|||||||
return PureLoraGenerator
|
return PureLoraGenerator
|
||||||
|
|
||||||
|
|
||||||
|
# This is for generic training (LoRA, Dreambooth, FineTuning)
|
||||||
|
class Img2ImgGeneratorExtension(Extension):
|
||||||
|
# uid must be unique, it is how the extension is identified
|
||||||
|
uid = "batch_img2img"
|
||||||
|
|
||||||
|
# name is the name of the extension for printing
|
||||||
|
name = "Img2ImgGeneratorExtension"
|
||||||
|
|
||||||
|
# This is where your process class is loaded
|
||||||
|
# keep your imports in here so they don't slow down the rest of the program
|
||||||
|
@classmethod
|
||||||
|
def get_process(cls):
|
||||||
|
# import your process class here so it is only loaded when needed and return it
|
||||||
|
from .Img2ImgGenerator import Img2ImgGenerator
|
||||||
|
return Img2ImgGenerator
|
||||||
|
|
||||||
|
|
||||||
AI_TOOLKIT_EXTENSIONS = [
|
AI_TOOLKIT_EXTENSIONS = [
|
||||||
# you can put a list of extensions here
|
# you can put a list of extensions here
|
||||||
AdvancedReferenceGeneratorExtension, PureLoraGenerator
|
AdvancedReferenceGeneratorExtension, PureLoraGenerator, Img2ImgGeneratorExtension
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -483,6 +483,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
noise=noise,
|
noise=noise,
|
||||||
sd=self.sd,
|
sd=self.sd,
|
||||||
unconditional_embeds=unconditional_embeds,
|
unconditional_embeds=unconditional_embeds,
|
||||||
|
scaler=self.scaler,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -739,6 +739,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# add to noise
|
# add to noise
|
||||||
noise += noise_shift
|
noise += noise_shift
|
||||||
|
|
||||||
|
# standardize the noise
|
||||||
|
std = noise.std(dim=(2, 3), keepdim=True)
|
||||||
|
normalizer = 1 / (std + 1e-6)
|
||||||
|
noise = noise * normalizer
|
||||||
|
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
|
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||||
@@ -975,14 +980,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
noise = noise * noise_multiplier
|
noise = noise * noise_multiplier
|
||||||
|
|
||||||
latents = latents * self.train_config.latent_multiplier
|
latent_multiplier = self.train_config.latent_multiplier
|
||||||
|
|
||||||
|
# handle adaptive scaling mased on std
|
||||||
|
if self.train_config.adaptive_scaling_factor:
|
||||||
|
std = latents.std(dim=(2, 3), keepdim=True)
|
||||||
|
normalizer = 1 / (std + 1e-6)
|
||||||
|
latent_multiplier = normalizer
|
||||||
|
|
||||||
|
latents = latents * latent_multiplier
|
||||||
|
batch.latents = latents
|
||||||
|
|
||||||
# normalize latents to a mean of 0 and an std of 1
|
# normalize latents to a mean of 0 and an std of 1
|
||||||
# mean_zero_latents = latents - latents.mean()
|
# mean_zero_latents = latents - latents.mean()
|
||||||
# latents = mean_zero_latents / mean_zero_latents.std()
|
# latents = mean_zero_latents / mean_zero_latents.std()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if batch.unconditional_latents is not None:
|
if batch.unconditional_latents is not None:
|
||||||
batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier
|
batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier
|
||||||
|
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ class GenerateProcess(BaseProcess):
|
|||||||
self.model_config = ModelConfig(**self.get_conf('model', required=True))
|
self.model_config = ModelConfig(**self.get_conf('model', required=True))
|
||||||
self.device = self.get_conf('device', self.job.device)
|
self.device = self.get_conf('device', self.job.device)
|
||||||
self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
|
self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
|
||||||
|
self.torch_dtype = get_torch_dtype(self.get_conf('dtype', 'float16'))
|
||||||
|
|
||||||
self.progress_bar = None
|
self.progress_bar = None
|
||||||
self.sd = StableDiffusion(
|
self.sd = StableDiffusion(
|
||||||
@@ -87,49 +88,57 @@ class GenerateProcess(BaseProcess):
|
|||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
dtype=self.model_config.dtype,
|
dtype=self.model_config.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Using device {self.device}")
|
print(f"Using device {self.device}")
|
||||||
|
|
||||||
|
def clean_prompt(self, prompt: str):
|
||||||
|
# remove any non alpha numeric characters or ,'" from prompt
|
||||||
|
return ''.join(e for e in prompt if e.isalnum() or e in ", '\"")
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
super().run()
|
with torch.no_grad():
|
||||||
print("Loading model...")
|
super().run()
|
||||||
self.sd.load_model()
|
print("Loading model...")
|
||||||
|
self.sd.load_model()
|
||||||
|
self.sd.pipeline.to(self.device, self.torch_dtype)
|
||||||
|
|
||||||
print("Compiling model...")
|
print("Compiling model...")
|
||||||
# self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
|
# self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
|
||||||
if self.generate_config.compile:
|
if self.generate_config.compile:
|
||||||
self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead")
|
self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead")
|
||||||
|
|
||||||
print(f"Generating {len(self.generate_config.prompts)} images")
|
print(f"Generating {len(self.generate_config.prompts)} images")
|
||||||
# build prompt image configs
|
# build prompt image configs
|
||||||
prompt_image_configs = []
|
prompt_image_configs = []
|
||||||
for prompt in self.generate_config.prompts:
|
for prompt in self.generate_config.prompts:
|
||||||
width = self.generate_config.width
|
width = self.generate_config.width
|
||||||
height = self.generate_config.height
|
height = self.generate_config.height
|
||||||
|
prompt = self.clean_prompt(prompt)
|
||||||
|
|
||||||
if self.generate_config.size_list is not None:
|
if self.generate_config.size_list is not None:
|
||||||
# randomly select a size
|
# randomly select a size
|
||||||
width, height = random.choice(self.generate_config.size_list)
|
width, height = random.choice(self.generate_config.size_list)
|
||||||
|
|
||||||
prompt_image_configs.append(GenerateImageConfig(
|
prompt_image_configs.append(GenerateImageConfig(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
prompt_2=self.generate_config.prompt_2,
|
prompt_2=self.generate_config.prompt_2,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
num_inference_steps=self.generate_config.sample_steps,
|
num_inference_steps=self.generate_config.sample_steps,
|
||||||
guidance_scale=self.generate_config.guidance_scale,
|
guidance_scale=self.generate_config.guidance_scale,
|
||||||
negative_prompt=self.generate_config.neg,
|
negative_prompt=self.generate_config.neg,
|
||||||
negative_prompt_2=self.generate_config.neg_2,
|
negative_prompt_2=self.generate_config.neg_2,
|
||||||
seed=self.generate_config.seed,
|
seed=self.generate_config.seed,
|
||||||
guidance_rescale=self.generate_config.guidance_rescale,
|
guidance_rescale=self.generate_config.guidance_rescale,
|
||||||
output_ext=self.generate_config.ext,
|
output_ext=self.generate_config.ext,
|
||||||
output_folder=self.output_folder,
|
output_folder=self.output_folder,
|
||||||
add_prompt_file=self.generate_config.prompt_file
|
add_prompt_file=self.generate_config.prompt_file
|
||||||
))
|
))
|
||||||
# generate images
|
# generate images
|
||||||
self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler)
|
self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler)
|
||||||
|
|
||||||
print("Done generating images")
|
print("Done generating images")
|
||||||
# cleanup
|
# cleanup
|
||||||
del self.sd
|
del self.sd
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|||||||
@@ -266,6 +266,8 @@ class TrainConfig:
|
|||||||
self.reg_weight = kwargs.get('reg_weight', 1.0)
|
self.reg_weight = kwargs.get('reg_weight', 1.0)
|
||||||
self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
|
self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
|
||||||
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
|
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
|
||||||
|
# automatically adapte the vae scaling based on the image norm
|
||||||
|
self.adaptive_scaling_factor = kwargs.get('adaptive_scaling_factor', False)
|
||||||
|
|
||||||
# dropout that happens before encoding. It functions independently per text encoder
|
# dropout that happens before encoding. It functions independently per text encoder
|
||||||
self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
|
self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
|
||||||
|
|||||||
@@ -387,7 +387,7 @@ class CaptionProcessingDTOMixin:
|
|||||||
|
|
||||||
# join back together
|
# join back together
|
||||||
caption = ', '.join(token_list)
|
caption = ', '.join(token_list)
|
||||||
caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
|
# caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
|
||||||
|
|
||||||
if self.dataset_config.random_triggers:
|
if self.dataset_config.random_triggers:
|
||||||
num_triggers = self.dataset_config.random_triggers_max
|
num_triggers = self.dataset_config.random_triggers_max
|
||||||
|
|||||||
@@ -407,6 +407,7 @@ def get_guided_loss_polarity(
|
|||||||
batch: 'DataLoaderBatchDTO',
|
batch: 'DataLoaderBatchDTO',
|
||||||
noise: torch.Tensor,
|
noise: torch.Tensor,
|
||||||
sd: 'StableDiffusion',
|
sd: 'StableDiffusion',
|
||||||
|
scaler=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
dtype = get_torch_dtype(sd.torch_dtype)
|
dtype = get_torch_dtype(sd.torch_dtype)
|
||||||
@@ -473,7 +474,10 @@ def get_guided_loss_polarity(
|
|||||||
|
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
loss.backward()
|
if scaler is not None:
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
else:
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
# detach it so parent class can run backward on no grads without throwing error
|
# detach it so parent class can run backward on no grads without throwing error
|
||||||
loss = loss.detach()
|
loss = loss.detach()
|
||||||
@@ -590,6 +594,7 @@ def get_guidance_loss(
|
|||||||
unconditional_embeds: Optional[PromptEmbeds] = None,
|
unconditional_embeds: Optional[PromptEmbeds] = None,
|
||||||
mask_multiplier=None,
|
mask_multiplier=None,
|
||||||
prior_pred=None,
|
prior_pred=None,
|
||||||
|
scaler=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
# TODO add others and process individual batch items separately
|
# TODO add others and process individual batch items separately
|
||||||
@@ -621,6 +626,7 @@ def get_guidance_loss(
|
|||||||
batch,
|
batch,
|
||||||
noise,
|
noise,
|
||||||
sd,
|
sd,
|
||||||
|
scaler=scaler,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
elif guidance_type == "tnt":
|
elif guidance_type == "tnt":
|
||||||
|
|||||||
@@ -41,9 +41,12 @@ sd_config = {
|
|||||||
"prediction_type": "epsilon",
|
"prediction_type": "epsilon",
|
||||||
"sample_max_value": 1.0,
|
"sample_max_value": 1.0,
|
||||||
"set_alpha_to_one": False,
|
"set_alpha_to_one": False,
|
||||||
"skip_prk_steps": False,
|
# "skip_prk_steps": False, # for training
|
||||||
"steps_offset": 1,
|
"skip_prk_steps": True,
|
||||||
"timestep_spacing": "trailing",
|
# "steps_offset": 1,
|
||||||
|
"steps_offset": 0,
|
||||||
|
# "timestep_spacing": "trailing", # for training
|
||||||
|
"timestep_spacing": "leading",
|
||||||
"trained_betas": None
|
"trained_betas": None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user