https://github.com/ostris/ai-toolkit.git
Bug fixes and little improvements here and there.
@@ -4,7 +4,7 @@ from collections import OrderedDict
 from typing import Union, Literal, List, Optional

 import numpy as np
-from diffusers import T2IAdapter, AutoencoderTiny
+from diffusers import T2IAdapter, AutoencoderTiny, ControlNetModel

 import torch.functional as F
 from safetensors.torch import load_file
@@ -824,6 +824,10 @@ class SDTrainer(BaseSDTrainProcess):
             # remove the residuals as we won't use them on prediction when matching control
             if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs:
                 del pred_kwargs['down_intrablock_additional_residuals']
+            if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
+                del pred_kwargs['down_block_additional_residuals']
+            if match_adapter_assist and 'mid_block_additional_residual' in pred_kwargs:
+                del pred_kwargs['mid_block_additional_residual']

             if can_disable_adapter:
                 self.adapter.is_active = was_adapter_active
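The two added deletions matter because pred_kwargs is reused for the follow-up
prediction when matching the adapter assist: previously only the T2I-Adapter key
was stripped, so the ControlNet residual keys introduced later in this commit
would have leaked into that prediction. A hedged sketch of the same cleanup as a
single helper (the helper and constant names are hypothetical, not from the repo):

    ADAPTER_RESIDUAL_KEYS = (
        'down_intrablock_additional_residuals',  # T2I-Adapter residuals
        'down_block_additional_residuals',       # ControlNet down-block residuals
        'mid_block_additional_residual',         # ControlNet mid-block residual
    )

    def strip_adapter_residuals(pred_kwargs: dict) -> dict:
        # drop every adapter/ControlNet residual before the matching prediction
        return {k: v for k, v in pred_kwargs.items() if k not in ADAPTER_RESIDUAL_KEYS}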
@@ -1065,7 +1069,7 @@ class SDTrainer(BaseSDTrainProcess):
         # if prompt_2 is not None:
         #     prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]

-        with network:
+        with (network):
             # encode clip adapter here so embeds are active for tokenizer
             if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
                 with self.timer('encode_clip_vision_embeds'):
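Note on the one-character change above: with a single context manager,
with (network): parses as an ordinary parenthesized expression, so it behaves
exactly like with network:. The parenthesized form only becomes meaningful when
several context managers are listed inside the parentheses (Python 3.10+); here
the change is cosmetic.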
@@ -1162,26 +1166,27 @@ class SDTrainer(BaseSDTrainProcess):

             # flush()
             pred_kwargs = {}
-            if has_adapter_img and (
-                    (self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
-                with torch.set_grad_enabled(self.adapter is not None):
-                    adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
-                    adapter_multiplier = get_adapter_multiplier()
-                    with self.timer('encode_adapter'):
-                        down_block_additional_residuals = adapter(adapter_images)
-                        if self.assistant_adapter:
-                            # not training. detach
-                            down_block_additional_residuals = [
-                                sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
-                                down_block_additional_residuals
-                            ]
-                        else:
-                            down_block_additional_residuals = [
-                                sample.to(dtype=dtype) * adapter_multiplier for sample in
-                                down_block_additional_residuals
-                            ]
-
-                pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals
+            if has_adapter_img:
+                if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
+                    with torch.set_grad_enabled(self.adapter is not None):
+                        adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
+                        adapter_multiplier = get_adapter_multiplier()
+                        with self.timer('encode_adapter'):
+                            down_block_additional_residuals = adapter(adapter_images)
+                            if self.assistant_adapter:
+                                # not training. detach
+                                down_block_additional_residuals = [
+                                    sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
+                                    down_block_additional_residuals
+                                ]
+                            else:
+                                down_block_additional_residuals = [
+                                    sample.to(dtype=dtype) * adapter_multiplier for sample in
+                                    down_block_additional_residuals
+                                ]
+
+                    pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals

             if self.adapter and isinstance(self.adapter, IPAdapter):
                 with self.timer('encode_adapter_embeds'):
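The rewritten condition is the substantive fix in this hunk: the old
"or self.assistant_adapter" clause sent any assistant adapter down the
T2I-Adapter path, which would call a ControlNet as if it were a T2I-Adapter
taking a bare image batch. The new form requires an actual T2IAdapter on either
side, leaving ControlNet assistants to the dedicated branch added further down.
For orientation, a hedged sketch of how the stored residuals are consumed later
(the keyword follows the diffusers UNet2DConditionModel API for T2I-Adapter
residuals; variable names are taken from the hunk, the unet handle is assumed):

    noise_pred = unet(
        noisy_latents,
        timesteps,
        encoder_hidden_states=conditional_embeds.text_embeds,
        **pred_kwargs,  # carries down_intrablock_additional_residuals
    ).sample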
@@ -1362,6 +1367,32 @@ class SDTrainer(BaseSDTrainProcess):
                     if self.train_config.do_cfg:
                         self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), is_unconditional=True)

+            if has_adapter_img:
+                if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
+                    if self.train_config.do_cfg:
+                        raise ValueError("ControlNetModel is not supported with CFG")
+                    with torch.set_grad_enabled(self.adapter is not None):
+                        adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
+                        adapter_multiplier = get_adapter_multiplier()
+                        with self.timer('encode_adapter'):
+                            # add_text_embeds is pooled_prompt_embeds for sdxl
+                            added_cond_kwargs = {}
+                            if self.sd.is_xl:
+                                added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds
+                                added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents)
+                            down_block_res_samples, mid_block_res_sample = adapter(
+                                noisy_latents,
+                                timesteps,
+                                encoder_hidden_states=conditional_embeds.text_embeds,
+                                controlnet_cond=adapter_images,
+                                conditioning_scale=1.0,
+                                guess_mode=False,
+                                added_cond_kwargs=added_cond_kwargs,
+                                return_dict=False,
+                            )
+                    pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
+                    pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
+

             self.before_unet_predict()
             # do a prior pred if we have an unconditional image, we will swap out the guidance later
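This new branch is the main feature of the commit: a ControlNet (the trained
adapter or a frozen assistant) runs on the noisy latents and its residuals ride
into the UNet through pred_kwargs. A standalone sketch of that handoff using the
public diffusers API (the SD 1.5 setup, model ids, and tensor shapes are
illustrative assumptions, not taken from the commit):

    import torch
    from diffusers import ControlNetModel, UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet")
    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

    noisy_latents = torch.randn(1, 4, 64, 64)   # latent-space sample
    timesteps = torch.tensor([10])
    text_embeds = torch.randn(1, 77, 768)       # CLIP text-encoder output
    control_image = torch.rand(1, 3, 512, 512)  # pixel-space conditioning image

    with torch.no_grad():
        down_res, mid_res = controlnet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=text_embeds,
            controlnet_cond=control_image,
            conditioning_scale=1.0,
            guess_mode=False,
            return_dict=False,
        )
        # the same kwargs the trainer stores in pred_kwargs
        noise_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=text_embeds,
            down_block_additional_residuals=down_res,
            mid_block_additional_residual=mid_res,
        ).sample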
@@ -1423,10 +1454,10 @@ class SDTrainer(BaseSDTrainProcess):
                     # 0.0 for the backward pass and the gradients will be 0.0
                     # I spent weeks on fighting this. DON'T DO IT
                     # with fsdp_overlap_step_with_backward():
-                    if self.is_bfloat:
-                        loss.backward()
-                    else:
-                        self.scaler.scale(loss).backward()
+                    # if self.is_bfloat:
+                    loss.backward()
+                    # else:
+                    #     self.scaler.scale(loss).backward()
                     # flush()

                 if not self.is_grad_accumulation_step:
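The change above drops the float16 GradScaler path and always calls plain
loss.backward(). The rationale (a general PyTorch AMP fact, not stated in the
commit) is that loss scaling exists to keep small float16 gradients from
underflowing; bfloat16 shares float32's exponent range, so no scaling is needed.
The two canonical paths side by side:

    import torch

    scaler = torch.cuda.amp.GradScaler()  # only needed for float16

    def backward_step(loss: torch.Tensor, use_bf16: bool) -> None:
        if use_bf16:
            loss.backward()                # bf16/fp32: gradients will not underflow
        else:
            scaler.scale(loss).backward()  # fp16: scale up to avoid underflow

With the else-branch commented out, the trainer stays correct only while it
runs in bfloat16 or full precision.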
@@ -1443,8 +1474,8 @@ class SDTrainer(BaseSDTrainProcess):
                         self.optimizer.step()
                     else:
                         # apply gradients
-                        self.scaler.step(self.optimizer)
-                        self.scaler.update()
+                        self.optimizer.step()
+                        # self.scaler.update()
+                        # self.optimizer.step()
                     self.optimizer.zero_grad(set_to_none=True)
                 else:
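This mirrors the backward-pass change above: scaler.step() and scaler.update()
are only valid after scaler.scale(loss).backward(), so once the backward is
unscaled the optimizer has to be stepped directly. The two pairings, which must
never be mixed (standard PyTorch AMP usage, shown here for contrast):

    # float16 with GradScaler: scale, step (unscales + inf/nan check), update
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # bfloat16 / float32: a plain backward pairs with a plain step
    loss.backward()
    optimizer.step()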