mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Added ability to add masks to dataloader and sd trainer to adjust weight of image
This commit is contained in:
@@ -98,18 +98,31 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
network_weight_list = batch.get_network_weight_list()
|
||||
|
||||
adapter_images = None
|
||||
sigmas = None
|
||||
if self.adapter:
|
||||
# todo move this to data loader
|
||||
if batch.control_tensor is not None:
|
||||
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
|
||||
else:
|
||||
adapter_images = self.get_adapter_images(batch)
|
||||
# not 100% sure what this does. But they do it here
|
||||
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
|
||||
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
|
||||
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
|
||||
with torch.no_grad():
|
||||
adapter_images = None
|
||||
sigmas = None
|
||||
if self.adapter:
|
||||
# todo move this to data loader
|
||||
if batch.control_tensor is not None:
|
||||
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
|
||||
else:
|
||||
adapter_images = self.get_adapter_images(batch)
|
||||
# not 100% sure what this does. But they do it here
|
||||
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
|
||||
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
|
||||
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
|
||||
|
||||
mask_multiplier = 1.0
|
||||
if batch.mask_tensor is not None:
|
||||
# upsampling no supported for bfloat16
|
||||
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
|
||||
# scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
|
||||
mask_multiplier = torch.nn.functional.interpolate(
|
||||
mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
|
||||
)
|
||||
# expand to match latents
|
||||
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
|
||||
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
|
||||
|
||||
# flush()
|
||||
self.optimizer.zero_grad()
|
||||
@@ -188,6 +201,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
else:
|
||||
target = noise
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
# multiply by our mask
|
||||
loss = loss * mask_multiplier
|
||||
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
|
||||
Reference in New Issue
Block a user