Fixes and longer prompts

This commit is contained in:
Jaret Burkett
2023-10-22 08:57:37 -06:00
parent 0e9fc42816
commit 9905a1e205
4 changed files with 177 additions and 54 deletions

View File

@@ -4,6 +4,7 @@ from diffusers import T2IAdapter
from toolkit.basic import value_map
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.ip_adapter import IPAdapter
from toolkit.prompt_utils import PromptEmbeds
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
import gc
@@ -27,6 +28,10 @@ class SDTrainer(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
super().__init__(process_id, job, config, **kwargs)
self.assistant_adapter: Union['T2IAdapter', None]
self.do_prior_prediction = False
if self.train_config.inverted_mask_prior:
self.do_prior_prediction = True
def before_model_load(self):
pass
@@ -135,6 +140,40 @@ class SDTrainer(BaseSDTrainProcess):
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
return batch
def get_prior_prediction(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
**kwargs
):
# do a prediction here so we can match its output with network multiplier set to 0.0
with torch.no_grad():
dtype = get_torch_dtype(self.train_config.dtype)
# dont use network on this
self.network.multiplier = 0.0
self.sd.unet.eval()
prior_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
self.sd.unet.train()
prior_pred = prior_pred.detach()
# remove the residuals as we wont use them on prediction when matching control
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
del pred_kwargs['down_block_additional_residuals']
# restore network
self.network.multiplier = network_weight_list
return prior_pred
def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
self.timer.start('preprocess_batch')
@@ -287,28 +326,18 @@ class SDTrainer(BaseSDTrainProcess):
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
prior_pred = None
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.train_config.inverted_mask_prior:
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
with self.timer('prior predict'):
# do a prediction here so we can match its output with network multiplier set to 0.0
with torch.no_grad():
# dont use network on this
network.multiplier = 0.0
self.sd.unet.eval()
prior_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
self.sd.unet.train()
prior_pred = prior_pred.detach()
# remove the residuals as we wont use them on prediction when matching control
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
del pred_kwargs['down_block_additional_residuals']
# restore network
network.multiplier = network_weight_list
prior_pred = self.get_prior_prediction(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
noise=noise,
batch=batch,
)
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter'):