Bug fixes and little improvements here and there.

This commit is contained in:
Jaret Burkett
2024-06-08 06:24:20 -06:00
parent 833c833f28
commit 3f3636b788
12 changed files with 358 additions and 117 deletions

View File

@@ -4,7 +4,7 @@ from collections import OrderedDict
from typing import Union, Literal, List, Optional
import numpy as np
from diffusers import T2IAdapter, AutoencoderTiny
from diffusers import T2IAdapter, AutoencoderTiny, ControlNetModel
import torch.functional as F
from safetensors.torch import load_file
@@ -824,6 +824,10 @@ class SDTrainer(BaseSDTrainProcess):
# remove the residuals as we wont use them on prediction when matching control
if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs:
del pred_kwargs['down_intrablock_additional_residuals']
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
del pred_kwargs['down_block_additional_residuals']
if match_adapter_assist and 'mid_block_additional_residual' in pred_kwargs:
del pred_kwargs['mid_block_additional_residual']
if can_disable_adapter:
self.adapter.is_active = was_adapter_active
@@ -1065,7 +1069,7 @@ class SDTrainer(BaseSDTrainProcess):
# if prompt_2 is not None:
# prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
with network:
with (network):
# encode clip adapter here so embeds are active for tokenizer
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
with self.timer('encode_clip_vision_embeds'):
@@ -1162,26 +1166,27 @@ class SDTrainer(BaseSDTrainProcess):
# flush()
pred_kwargs = {}
if has_adapter_img and (
(self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
with torch.set_grad_enabled(self.adapter is not None):
adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
adapter_multiplier = get_adapter_multiplier()
with self.timer('encode_adapter'):
down_block_additional_residuals = adapter(adapter_images)
if self.assistant_adapter:
# not training. detach
down_block_additional_residuals = [
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
down_block_additional_residuals
]
else:
down_block_additional_residuals = [
sample.to(dtype=dtype) * adapter_multiplier for sample in
down_block_additional_residuals
]
pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals
if has_adapter_img:
if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
with torch.set_grad_enabled(self.adapter is not None):
adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
adapter_multiplier = get_adapter_multiplier()
with self.timer('encode_adapter'):
down_block_additional_residuals = adapter(adapter_images)
if self.assistant_adapter:
# not training. detach
down_block_additional_residuals = [
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
down_block_additional_residuals
]
else:
down_block_additional_residuals = [
sample.to(dtype=dtype) * adapter_multiplier for sample in
down_block_additional_residuals
]
pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals
if self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter_embeds'):
@@ -1362,6 +1367,32 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.do_cfg:
self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), is_unconditional=True)
if has_adapter_img:
if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
if self.train_config.do_cfg:
raise ValueError("ControlNetModel is not supported with CFG")
with torch.set_grad_enabled(self.adapter is not None):
adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
adapter_multiplier = get_adapter_multiplier()
with self.timer('encode_adapter'):
# add_text_embeds is pooled_prompt_embeds for sdxl
added_cond_kwargs = {}
if self.sd.is_xl:
added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds
added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents)
down_block_res_samples, mid_block_res_sample = adapter(
noisy_latents,
timesteps,
encoder_hidden_states=conditional_embeds.text_embeds,
controlnet_cond=adapter_images,
conditioning_scale=1.0,
guess_mode=False,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)
pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
self.before_unet_predict()
# do a prior pred if we have an unconditional image, we will swap out the giadance later
@@ -1423,10 +1454,10 @@ class SDTrainer(BaseSDTrainProcess):
# 0.0 for the backward pass and the gradients will be 0.0
# I spent weeks on fighting this. DON'T DO IT
# with fsdp_overlap_step_with_backward():
if self.is_bfloat:
loss.backward()
else:
self.scaler.scale(loss).backward()
# if self.is_bfloat:
loss.backward()
# else:
# self.scaler.scale(loss).backward()
# flush()
if not self.is_grad_accumulation_step:
@@ -1443,8 +1474,8 @@ class SDTrainer(BaseSDTrainProcess):
self.optimizer.step()
else:
# apply gradients
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.step()
# self.scaler.update()
# self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
else: