Various bug fixes, wip stuff, and tweaks

This commit is contained in:
Jaret Burkett
2023-11-02 18:19:20 -06:00
parent 7d707b2fe6
commit ceaf1d9454
5 changed files with 472 additions and 11 deletions

View File

@@ -3,6 +3,8 @@ import time
from typing import List, Optional, Literal, Union
import random
import torch
from toolkit.prompt_utils import PromptEmbeds
ImgExt = Literal['jpg', 'png', 'webp']
@@ -184,6 +186,11 @@ class TrainConfig:
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
# match the norm of the noise before computing loss. This will help the model maintain its
#current understandin of the brightness of images.
self.match_noise_norm = kwargs.get('match_noise_norm', False)
# set to -1 to accumulate gradients for entire epoch
# warning, only do this with a small dataset or you will run out of memory
self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)
@@ -406,6 +413,8 @@ class GenerateImageConfig:
add_prompt_file: bool = False, # add a prompt file with generated image
adapter_image_path: str = None, # path to adapter image
adapter_conditioning_scale: float = 1.0, # scale for adapter conditioning
latents: Union[torch.Tensor | None] = None, # input latent to start with,
extra_kwargs: dict = None, # extra data to save with prompt file
):
self.width: int = width
self.height: int = height
@@ -416,6 +425,7 @@ class GenerateImageConfig:
self.prompt_2: str = prompt_2
self.negative_prompt: str = negative_prompt
self.negative_prompt_2: str = negative_prompt_2
self.latents: Union[torch.Tensor | None] = latents
self.output_path: str = output_path
self.seed: int = seed
@@ -430,6 +440,7 @@ class GenerateImageConfig:
self.gen_time: int = int(time.time() * 1000)
self.adapter_image_path: str = adapter_image_path
self.adapter_conditioning_scale: float = adapter_conditioning_scale
self.extra_kwargs = extra_kwargs if extra_kwargs is not None else {}
# prompt string will override any settings above
self._process_prompt_string()