t2i training working from what I can tell at least

This commit is contained in:
Jaret Burkett
2023-09-17 15:56:43 -06:00
parent 181f237a7b
commit 61badf85a7
5 changed files with 214 additions and 174 deletions

View File

@@ -54,16 +54,20 @@ class SDTrainer(BaseSDTrainProcess):
if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)):
adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext))
break
width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height
adapter_tensors = []
# load images with torch transforms
for adapter_image in adapter_images:
for idx, adapter_image in enumerate(adapter_images):
img = Image.open(adapter_image)
# resize to match batch shape
img = img.resize((width, height))
img = adapter_transforms(img)
adapter_tensors.append(img)
# stack them
adapter_tensors = torch.stack(adapter_tensors)
adapter_tensors = torch.stack(adapter_tensors).to(
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
)
return adapter_tensors
def hook_train_loop(self, batch):
@@ -79,8 +83,8 @@ class SDTrainer(BaseSDTrainProcess):
adapter_images = self.get_adapter_images(batch)
# not 100% sure what this does. But they do it here
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
# flush()
self.optimizer.zero_grad()
@@ -126,39 +130,38 @@ class SDTrainer(BaseSDTrainProcess):
**pred_kwargs
)
if self.adapter:
# todo, diffusers does this on t2i training, is it better approach?
# Denoise the latents
denoised_latents = noise_pred * (-sigmas) + noisy_latents
weighing = sigmas ** -2.0
# Get the target for loss depending on the prediction type
if self.sd.noise_scheduler.config.prediction_type == "epsilon":
target = batch.latents # we are computing loss against denoise latents
elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
# MSE loss
loss = torch.mean(
(weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
dim=1,
)
# if self.adapter:
# # todo, diffusers does this on t2i training, is it better approach?
# # Denoise the latents
# denoised_latents = noise_pred * (-sigmas) + noisy_latents
# weighing = sigmas ** -2.0
#
# # Get the target for loss depending on the prediction type
# if self.sd.noise_scheduler.config.prediction_type == "epsilon":
# target = batch.latents # we are computing loss against denoise latents
# elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
# target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps)
# else:
# raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
#
# # MSE loss
# loss = torch.mean(
# (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
# dim=1,
# )
# else:
noise = noise.to(self.device_torch, dtype=dtype).detach()
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
noise = noise.to(self.device_torch, dtype=dtype).detach()
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
# TODO: I think the sigma method does not need this. Check
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()