Bug fixes, speed improvements, compatability adjustments withdiffusers updates

This commit is contained in:
Jaret Burkett
2023-09-13 07:03:53 -06:00
parent d8d1e6fd1e
commit ae70200d3c
8 changed files with 52 additions and 11 deletions

View File

@@ -77,7 +77,10 @@ class TrainConfig:
self.optimizer_params = kwargs.get('optimizer_params', {})
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
self.use_linear_denoising: int = kwargs.get('use_linear_denoising', False)
self.use_progressive_denoising: int = kwargs.get('use_progressive_denoising', False)
self.batch_size: int = kwargs.get('batch_size', 1)
self.dtype: str = kwargs.get('dtype', 'fp32')
self.xformers = kwargs.get('xformers', False)

View File

@@ -418,7 +418,6 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
file_item.load_caption(self.caption_dict)
return file_item
@lru_cache(maxsize=300)
def __getitem__(self, item):
if self.dataset_config.buckets:
# for buckets we collate ourselves for now

View File

@@ -53,7 +53,7 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag
self.tensor: Union[torch.Tensor, None] = None
def cleanup(self):
del self.tensor
self.tensor = None
self.cleanup_latent()

View File

@@ -334,6 +334,9 @@ class ToolkitNetworkMixin:
@multiplier.setter
def multiplier(self, value: Union[float, List[float]]):
# only update if the value has changed
if self._multiplier == value:
return
self._multiplier = value
self._update_lora_multiplier()

View File

@@ -138,9 +138,9 @@ class StableDiffusion:
self.noise_scheduler = scheduler
# move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why
self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
# self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
# self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
# self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
model_path = self.model_config.name_or_path
if 'civitai.com' in self.model_config.name_or_path: