mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-10 21:19:49 +00:00
Added some experimental training techniques. Ignore for now. Still in testing.
This commit is contained in:
@@ -325,6 +325,8 @@ class TrainConfig:
|
||||
self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i') # t2i, control_net
|
||||
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
|
||||
self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0)
|
||||
self.random_noise_multiplier = kwargs.get('random_noise_multiplier', 0.0)
|
||||
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
|
||||
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
|
||||
self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
|
||||
self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
|
||||
@@ -333,7 +335,6 @@ class TrainConfig:
|
||||
# multiplier applied to loos on regularization images
|
||||
self.reg_weight = kwargs.get('reg_weight', 1.0)
|
||||
self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
|
||||
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
|
||||
# automatically adapte the vae scaling based on the image norm
|
||||
self.adaptive_scaling_factor = kwargs.get('adaptive_scaling_factor', False)
|
||||
|
||||
@@ -412,7 +413,7 @@ class TrainConfig:
|
||||
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
|
||||
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
|
||||
|
||||
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace
|
||||
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm
|
||||
|
||||
# scale the prediction by this. Increase for more detail, decrease for less
|
||||
self.pred_scaler = kwargs.get('pred_scaler', 1.0)
|
||||
@@ -436,7 +437,8 @@ class TrainConfig:
|
||||
# adds an additional loss to the network to encourage it output a normalized standard deviation
|
||||
self.target_norm_std = kwargs.get('target_norm_std', None)
|
||||
self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
|
||||
self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend
|
||||
self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend, next_sample
|
||||
self.next_sample_timesteps = kwargs.get('next_sample_timesteps', 8)
|
||||
self.linear_timesteps = kwargs.get('linear_timesteps', False)
|
||||
self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
|
||||
self.disable_sampling = kwargs.get('disable_sampling', False)
|
||||
|
||||
@@ -142,7 +142,9 @@ class StableDiffusion:
|
||||
):
|
||||
self.accelerator = get_accelerator()
|
||||
self.custom_pipeline = custom_pipeline
|
||||
self.device = device
|
||||
self.device = str(device)
|
||||
if "cuda" in self.device and ":" not in self.device:
|
||||
self.device = f"{self.device}:0"
|
||||
self.device_torch = torch.device(device)
|
||||
self.dtype = dtype
|
||||
self.torch_dtype = get_torch_dtype(dtype)
|
||||
@@ -2086,7 +2088,10 @@ class StableDiffusion:
|
||||
noise_pred = noise_pred
|
||||
else:
|
||||
if self.unet.device != self.device_torch:
|
||||
self.unet.to(self.device_torch)
|
||||
try:
|
||||
self.unet.to(self.device_torch)
|
||||
except Exception as e:
|
||||
pass
|
||||
if self.unet.dtype != self.torch_dtype:
|
||||
self.unet = self.unet.to(dtype=self.torch_dtype)
|
||||
if self.is_flux:
|
||||
|
||||
Reference in New Issue
Block a user