mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
Bug fixes, speed improvements, compatability adjustments withdiffusers updates
This commit is contained in:
@@ -77,7 +77,10 @@ class TrainConfig:
|
||||
self.optimizer_params = kwargs.get('optimizer_params', {})
|
||||
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
|
||||
self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
|
||||
self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
|
||||
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
|
||||
self.use_linear_denoising: int = kwargs.get('use_linear_denoising', False)
|
||||
self.use_progressive_denoising: int = kwargs.get('use_progressive_denoising', False)
|
||||
self.batch_size: int = kwargs.get('batch_size', 1)
|
||||
self.dtype: str = kwargs.get('dtype', 'fp32')
|
||||
self.xformers = kwargs.get('xformers', False)
|
||||
|
||||
@@ -418,7 +418,6 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
file_item.load_caption(self.caption_dict)
|
||||
return file_item
|
||||
|
||||
@lru_cache(maxsize=300)
|
||||
def __getitem__(self, item):
|
||||
if self.dataset_config.buckets:
|
||||
# for buckets we collate ourselves for now
|
||||
|
||||
@@ -53,7 +53,7 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag
|
||||
self.tensor: Union[torch.Tensor, None] = None
|
||||
|
||||
def cleanup(self):
|
||||
del self.tensor
|
||||
self.tensor = None
|
||||
self.cleanup_latent()
|
||||
|
||||
|
||||
|
||||
@@ -334,6 +334,9 @@ class ToolkitNetworkMixin:
|
||||
|
||||
@multiplier.setter
|
||||
def multiplier(self, value: Union[float, List[float]]):
|
||||
# only update if the value has changed
|
||||
if self._multiplier == value:
|
||||
return
|
||||
self._multiplier = value
|
||||
self._update_lora_multiplier()
|
||||
|
||||
|
||||
@@ -138,9 +138,9 @@ class StableDiffusion:
|
||||
self.noise_scheduler = scheduler
|
||||
|
||||
# move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why
|
||||
self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
|
||||
self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
|
||||
self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
|
||||
# self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
|
||||
# self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
|
||||
# self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
|
||||
|
||||
model_path = self.model_config.name_or_path
|
||||
if 'civitai.com' in self.model_config.name_or_path:
|
||||
|
||||
Reference in New Issue
Block a user