mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
added a method to apply multipliers to noise and latents prior to combining
This commit is contained in:
@@ -5,6 +5,7 @@ from collections import OrderedDict
|
|||||||
import os
|
import os
|
||||||
from typing import Union, List
|
from typing import Union, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from diffusers import T2IAdapter
|
from diffusers import T2IAdapter
|
||||||
# from lycoris.config import PRESET
|
# from lycoris.config import PRESET
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@@ -538,6 +539,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
noise_offset=self.train_config.noise_offset
|
noise_offset=self.train_config.noise_offset
|
||||||
).to(self.device_torch, dtype=dtype)
|
).to(self.device_torch, dtype=dtype)
|
||||||
|
|
||||||
|
noise_multiplier = self.train_config.noise_multiplier
|
||||||
|
|
||||||
|
noise = noise * noise_multiplier
|
||||||
|
|
||||||
|
img_multiplier = self.train_config.img_multiplier
|
||||||
|
|
||||||
|
latents = latents * img_multiplier
|
||||||
|
|
||||||
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
|
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
# remove grads for these
|
# remove grads for these
|
||||||
@@ -997,10 +1006,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
self.progress_bar.set_postfix_str(prog_bar_string)
|
self.progress_bar.set_postfix_str(prog_bar_string)
|
||||||
|
|
||||||
|
# apply network normalizer if we are using it, not on regularization steps
|
||||||
|
if self.network is not None and self.network.is_normalizing and not is_reg_step:
|
||||||
|
with self.timer('apply_normalizer'):
|
||||||
|
self.network.apply_stored_normalizer()
|
||||||
|
|
||||||
|
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
|
||||||
|
if isinstance(batch, DataLoaderBatchDTO):
|
||||||
|
with self.timer('batch_cleanup'):
|
||||||
|
batch.cleanup()
|
||||||
|
|
||||||
# don't do on first step
|
# don't do on first step
|
||||||
if self.step_num != self.start_step:
|
if self.step_num != self.start_step:
|
||||||
if is_sample_step:
|
if is_sample_step:
|
||||||
self.progress_bar.pause()
|
self.progress_bar.pause()
|
||||||
|
flush()
|
||||||
# print above the progress bar
|
# print above the progress bar
|
||||||
if self.train_config.free_u:
|
if self.train_config.free_u:
|
||||||
self.sd.pipeline.disable_freeu()
|
self.sd.pipeline.disable_freeu()
|
||||||
@@ -1036,16 +1056,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# end of step
|
# end of step
|
||||||
self.step_num = step
|
self.step_num = step
|
||||||
|
|
||||||
# apply network normalizer if we are using it, not on regularization steps
|
|
||||||
if self.network is not None and self.network.is_normalizing and not is_reg_step:
|
|
||||||
with self.timer('apply_normalizer'):
|
|
||||||
self.network.apply_stored_normalizer()
|
|
||||||
|
|
||||||
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
|
|
||||||
if isinstance(batch, DataLoaderBatchDTO):
|
|
||||||
with self.timer('batch_cleanup'):
|
|
||||||
batch.cleanup()
|
|
||||||
|
|
||||||
# flush every 10 steps
|
# flush every 10 steps
|
||||||
# if self.step_num % 10 == 0:
|
# if self.step_num % 10 == 0:
|
||||||
# flush()
|
# flush()
|
||||||
|
|||||||
@@ -122,6 +122,8 @@ class TrainConfig:
|
|||||||
self.start_step = kwargs.get('start_step', None)
|
self.start_step = kwargs.get('start_step', None)
|
||||||
self.free_u = kwargs.get('free_u', False)
|
self.free_u = kwargs.get('free_u', False)
|
||||||
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
|
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
|
||||||
|
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
|
||||||
|
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
|
||||||
|
|
||||||
match_adapter_assist = kwargs.get('match_adapter_assist', False)
|
match_adapter_assist = kwargs.get('match_adapter_assist', False)
|
||||||
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
|
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
|
||||||
|
|||||||
Reference in New Issue
Block a user