mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
t2i training working from what I can tell at least
This commit is contained in:
@@ -54,16 +54,20 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)):
|
||||
adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext))
|
||||
break
|
||||
|
||||
width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height
|
||||
adapter_tensors = []
|
||||
# load images with torch transforms
|
||||
for adapter_image in adapter_images:
|
||||
for idx, adapter_image in enumerate(adapter_images):
|
||||
img = Image.open(adapter_image)
|
||||
# resize to match batch shape
|
||||
img = img.resize((width, height))
|
||||
img = adapter_transforms(img)
|
||||
adapter_tensors.append(img)
|
||||
|
||||
# stack them
|
||||
adapter_tensors = torch.stack(adapter_tensors)
|
||||
adapter_tensors = torch.stack(adapter_tensors).to(
|
||||
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
|
||||
)
|
||||
return adapter_tensors
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
@@ -79,8 +83,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
adapter_images = self.get_adapter_images(batch)
|
||||
# not 100% sure what this does. But they do it here
|
||||
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
|
||||
sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
|
||||
noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
|
||||
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
|
||||
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
|
||||
|
||||
# flush()
|
||||
self.optimizer.zero_grad()
|
||||
@@ -126,39 +130,38 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
**pred_kwargs
|
||||
)
|
||||
|
||||
if self.adapter:
|
||||
# todo, diffusers does this on t2i training, is it better approach?
|
||||
# Denoise the latents
|
||||
denoised_latents = noise_pred * (-sigmas) + noisy_latents
|
||||
weighing = sigmas ** -2.0
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if self.sd.noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = batch.latents # we are computing loss against denoise latents
|
||||
elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
|
||||
|
||||
# MSE loss
|
||||
loss = torch.mean(
|
||||
(weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
|
||||
dim=1,
|
||||
)
|
||||
# if self.adapter:
|
||||
# # todo, diffusers does this on t2i training, is it better approach?
|
||||
# # Denoise the latents
|
||||
# denoised_latents = noise_pred * (-sigmas) + noisy_latents
|
||||
# weighing = sigmas ** -2.0
|
||||
#
|
||||
# # Get the target for loss depending on the prediction type
|
||||
# if self.sd.noise_scheduler.config.prediction_type == "epsilon":
|
||||
# target = batch.latents # we are computing loss against denoise latents
|
||||
# elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps)
|
||||
# else:
|
||||
# raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
|
||||
#
|
||||
# # MSE loss
|
||||
# loss = torch.mean(
|
||||
# (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
|
||||
# dim=1,
|
||||
# )
|
||||
# else:
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
if self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
else:
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
if self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
target = noise
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
# TODO: I think the sigma method does not need this. Check
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user