Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-03-11 13:39:50 +00:00

Allow short and long caption combinations from the new captioning system. Merge the network into the model before inference and re-extract it when done; this doubles inference speed on LoCon models. Allow splitting a batch into individual components and running each through alone, which is basically gradient accumulation with a batch size of one.
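The merge/re-extract trick works because LoRA is linear: for a linear layer, W·x + m·s·(U·(D·x)) = (W + m·s·U·D)·x, so the low-rank delta can be folded into the base weight once and the extra matmuls skipped while sampling. Merging out is just merging in the negative weight. A minimal standalone sketch of the idea (the names base, lora_up, lora_down, and scale are illustrative, not the toolkit's API):

import torch

# Minimal sketch of merge-in / merge-out for a LoRA-wrapped linear layer.
base = torch.nn.Linear(64, 64, bias=False)
lora_down = torch.nn.Linear(64, 8, bias=False)   # D: rank-8 projection
lora_up = torch.nn.Linear(8, 64, bias=False)     # U: back to full width
scale = 0.5                                      # alpha / rank style scaling

x = torch.randn(4, 64)
original_weight = base.weight.detach().clone()

# unmerged forward: base output plus the low-rank path
y_unmerged = base(x) + scale * lora_up(lora_down(x))

# merge in: fold scale * (U @ D) into the base weight once
with torch.no_grad():
    base.weight += scale * (lora_up.weight @ lora_down.weight)
y_merged = base(x)  # plain forward now includes the LoRA delta
assert torch.allclose(y_unmerged, y_merged, atol=1e-5)

# merge out: subtract the same delta to restore the original weight
with torch.no_grad():
    base.weight -= scale * (lora_up.weight @ lora_down.weight)
assert torch.allclose(base.weight, original_weight, atol=1e-6)

Repeated merge/extract cycles accumulate rounding error, which is why the actual implementation further down does the arithmetic in float32 and casts back to the original dtype afterwards.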
@@ -172,7 +172,9 @@ class SDTrainer(BaseSDTrainProcess):
         with torch.no_grad():
             dtype = get_torch_dtype(self.train_config.dtype)
             # don't use the network for this
-            self.network.multiplier = 0.0
+            # self.network.multiplier = 0.0
+            was_network_active = self.network.is_active
+            self.network.is_active = False
             self.sd.unet.eval()
             prior_pred = self.sd.predict_noise(
                 latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
@@ -187,7 +189,8 @@ class SDTrainer(BaseSDTrainProcess):
             if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
                 del pred_kwargs['down_block_additional_residuals']
             # restore network
-            self.network.multiplier = network_weight_list
+            # self.network.multiplier = network_weight_list
+            self.network.is_active = was_network_active
             return prior_pred

     def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
@@ -197,6 +200,8 @@ class SDTrainer(BaseSDTrainProcess):
         dtype = get_torch_dtype(self.train_config.dtype)
         noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
         network_weight_list = batch.get_network_weight_list()
+        if self.train_config.single_item_batching:
+            network_weight_list = network_weight_list + network_weight_list

         has_adapter_img = batch.control_tensor is not None
@@ -234,7 +239,7 @@ class SDTrainer(BaseSDTrainProcess):
         # sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
         # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)

-        mask_multiplier = 1.0
+        mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
         if batch.mask_tensor is not None:
             with self.timer('get_mask_multiplier'):
                 # upsampling not supported for bfloat16
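Replacing the scalar with a (batch, 1, 1, 1) tensor matters for the new single-item batching path: a Python float cannot be split with torch.chunk, while the tensor broadcasts over the per-pixel loss and chunks cleanly per sample. A quick sketch of that behavior (shapes chosen arbitrarily):

import torch

loss = torch.rand(4, 4, 64, 64)            # per-pixel loss, (B, C, H, W)
mask_multiplier = torch.ones(4, 1, 1, 1)   # one weight per sample
mask_multiplier[2] = 0.0                   # e.g. fully mask out sample 2

weighted = loss * mask_multiplier          # broadcasts over C, H, W
assert weighted[2].abs().sum() == 0.0

# and it splits per sample, unlike a python float
per_item = torch.chunk(mask_multiplier, 4, dim=0)
assert per_item[0].shape == (1, 1, 1, 1)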
@@ -297,107 +302,152 @@ class SDTrainer(BaseSDTrainProcess):
         self.optimizer.zero_grad(set_to_none=True)

         # activate network if it exists
-        with network:
-            with self.timer('encode_prompt'):
-                if grad_on_text_encoder:
-                    with torch.set_grad_enabled(True):
-                        conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
-                            self.device_torch,
-                            dtype=dtype)
-                else:
-                    with torch.set_grad_enabled(False):
-                        # make sure it is in eval mode
-                        if isinstance(self.sd.text_encoder, list):
-                            for te in self.sd.text_encoder:
-                                te.eval()
-                        else:
-                            self.sd.text_encoder.eval()
-                        conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
-                            self.device_torch,
-                            dtype=dtype)
-
-            # detach the embeddings
-            conditional_embeds = conditional_embeds.detach()
+        # make the batch splits
+        if self.train_config.single_item_batching:
+            batch_size = noisy_latents.shape[0]
+            # chunk/split everything
+            noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
+            noise_list = torch.chunk(noise, batch_size, dim=0)
+            timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
+            conditioned_prompts_list = [[prompt] for prompt in conditioned_prompts]
+            if imgs is not None:
+                imgs_list = torch.chunk(imgs, batch_size, dim=0)
+            else:
+                imgs_list = [None for _ in range(batch_size)]
+            if adapter_images is not None:
+                adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0)
+            else:
+                adapter_images_list = [None for _ in range(batch_size)]
+            mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
+        else:
+            # put it all in an array
+            noisy_latents_list = [noisy_latents]
+            noise_list = [noise]
+            timesteps_list = [timesteps]
+            conditioned_prompts_list = [conditioned_prompts]
+            imgs_list = [imgs]
+            adapter_images_list = [adapter_images]
+            mask_multiplier_list = [mask_multiplier]

+        for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier in zip(
+                noisy_latents_list,
+                noise_list,
+                timesteps_list,
+                conditioned_prompts_list,
+                imgs_list,
+                adapter_images_list,
+                mask_multiplier_list
+        ):
+            with network:
+                with self.timer('encode_prompt'):
+                    if grad_on_text_encoder:
+                        with torch.set_grad_enabled(True):
+                            conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
+                                self.device_torch,
+                                dtype=dtype)
+                    else:
+                        with torch.set_grad_enabled(False):
+                            # make sure it is in eval mode
+                            if isinstance(self.sd.text_encoder, list):
+                                for te in self.sd.text_encoder:
+                                    te.eval()
+                            else:
+                                self.sd.text_encoder.eval()
+                            conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
+                                self.device_torch,
+                                dtype=dtype)

+                # detach the embeddings
+                conditional_embeds = conditional_embeds.detach()

-            # flush()
-            pred_kwargs = {}
-            if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
-                with torch.set_grad_enabled(self.adapter is not None):
-                    adapter = self.adapter if self.adapter else self.assistant_adapter
-                    adapter_multiplier = get_adapter_multiplier()
-                    with self.timer('encode_adapter'):
-                        down_block_additional_residuals = adapter(adapter_images)
-                        if self.assistant_adapter:
-                            # not training. detach
-                            down_block_additional_residuals = [
-                                sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
-                                down_block_additional_residuals
-                            ]
-                        else:
-                            down_block_additional_residuals = [
-                                sample.to(dtype=dtype) * adapter_multiplier for sample in
-                                down_block_additional_residuals
-                            ]
-
-                pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
+                # flush()
+                pred_kwargs = {}
+                if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
+                    with torch.set_grad_enabled(self.adapter is not None):
+                        adapter = self.adapter if self.adapter else self.assistant_adapter
+                        adapter_multiplier = get_adapter_multiplier()
+                        with self.timer('encode_adapter'):
+                            down_block_additional_residuals = adapter(adapter_images)
+                            if self.assistant_adapter:
+                                # not training. detach
+                                down_block_additional_residuals = [
+                                    sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
+                                    down_block_additional_residuals
+                                ]
+                            else:
+                                down_block_additional_residuals = [
+                                    sample.to(dtype=dtype) * adapter_multiplier for sample in
+                                    down_block_additional_residuals
+                                ]
+
+                    pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals

-            prior_pred = None
-            if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
-                with self.timer('prior predict'):
-                    prior_pred = self.get_prior_prediction(
-                        noisy_latents=noisy_latents,
-                        conditional_embeds=conditional_embeds,
-                        match_adapter_assist=match_adapter_assist,
-                        network_weight_list=network_weight_list,
-                        timesteps=timesteps,
-                        pred_kwargs=pred_kwargs,
-                        noise=noise,
-                        batch=batch,
-                    )
+                prior_pred = None
+                if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
+                    with self.timer('prior predict'):
+                        prior_pred = self.get_prior_prediction(
+                            noisy_latents=noisy_latents,
+                            conditional_embeds=conditional_embeds,
+                            match_adapter_assist=match_adapter_assist,
+                            network_weight_list=network_weight_list,
+                            timesteps=timesteps,
+                            pred_kwargs=pred_kwargs,
+                            noise=noise,
+                            batch=batch,
+                        )

-            if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
-                with torch.no_grad():
-                    conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
-                conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
+                if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
+                    with self.timer('encode_adapter'):
+                        with torch.no_grad():
+                            conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
+                        conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)

-            with self.timer('predict_unet'):
-                noise_pred = self.sd.predict_noise(
-                    latents=noisy_latents.to(self.device_torch, dtype=dtype),
-                    conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
-                    timestep=timesteps,
-                    guidance_scale=1.0,
-                    **pred_kwargs
-                )
+                with self.timer('predict_unet'):
+                    noise_pred = self.sd.predict_noise(
+                        latents=noisy_latents.to(self.device_torch, dtype=dtype),
+                        conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
+                        timestep=timesteps,
+                        guidance_scale=1.0,
+                        **pred_kwargs
+                    )

-            with self.timer('calculate_loss'):
-                noise = noise.to(self.device_torch, dtype=dtype).detach()
-                loss = self.calculate_loss(
-                    noise_pred=noise_pred,
-                    noise=noise,
-                    noisy_latents=noisy_latents,
-                    timesteps=timesteps,
-                    batch=batch,
-                    mask_multiplier=mask_multiplier,
-                    prior_pred=prior_pred,
-                )
-            # check if nan
-            if torch.isnan(loss):
-                raise ValueError("loss is nan")
+                with self.timer('calculate_loss'):
+                    noise = noise.to(self.device_torch, dtype=dtype).detach()
+                    loss = self.calculate_loss(
+                        noise_pred=noise_pred,
+                        noise=noise,
+                        noisy_latents=noisy_latents,
+                        timesteps=timesteps,
+                        batch=batch,
+                        mask_multiplier=mask_multiplier,
+                        prior_pred=prior_pred,
+                    )
+                # check if nan
+                if torch.isnan(loss):
+                    raise ValueError("loss is nan")

-            with self.timer('backward'):
-                # IMPORTANT if gradient checkpointing do not leave with network when doing backward
-                # it will destroy the gradients. This is because the network is a context manager
-                # and will change the multipliers back to 0.0 when exiting. They will be
-                # 0.0 for the backward pass and the gradients will be 0.0
-                # I spent weeks fighting this. DON'T DO IT
-                # with fsdp_overlap_step_with_backward():
-                loss.backward()
-        torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
+                with self.timer('backward'):
+                    # IMPORTANT if gradient checkpointing do not leave with network when doing backward
+                    # it will destroy the gradients. This is because the network is a context manager
+                    # and will change the multipliers back to 0.0 when exiting. They will be
+                    # 0.0 for the backward pass and the gradients will be 0.0
+                    # I spent weeks fighting this. DON'T DO IT
+                    loss.backward()
+        torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
         # flush()

         with self.timer('optimizer_step'):
             # apply gradients
             self.optimizer.step()
             self.optimizer.zero_grad(set_to_none=True)
         with self.timer('scheduler_step'):
             self.lr_scheduler.step()
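The loop above is gradient accumulation in its simplest form: loss.backward() runs once per item with no optimizer.step() or zero_grad() in between, so gradients sum across the chunks and the single step afterwards sees the same total gradient as a full-batch pass, while only one item's activations are live at a time. A toy standalone demonstration of that equivalence (illustrative only; with a mean-reduced loss each per-item gradient would additionally need a 1/batch_size scale, so sum reduction is used here to keep the comparison exact):

import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 1)
x, y = torch.randn(4, 8), torch.randn(4, 1)

# full-batch gradient
model.zero_grad()
torch.nn.functional.mse_loss(model(x), y, reduction='sum').backward()
full_grad = model.weight.grad.clone()

# single-item batching: backward per chunk, gradients accumulate
model.zero_grad()
for xi, yi in zip(torch.chunk(x, 4), torch.chunk(y, 4)):
    torch.nn.functional.mse_loss(model(xi), yi, reduction='sum').backward()

assert torch.allclose(full_grad, model.weight.grad, atol=1e-5)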
@@ -460,6 +460,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
         prompts = batch.get_caption_list()
         is_reg_list = batch.get_is_reg_list()

+        is_any_reg = any([is_reg for is_reg in is_reg_list])
+
+        do_double = self.train_config.short_and_long_captions and not is_any_reg
+
+        if self.train_config.short_and_long_captions and do_double:
+            # don't do this with regs. No point
+
+            # double the batch and add short captions to the end
+            prompts = prompts + batch.get_caption_short_list()
+            is_reg_list = is_reg_list + is_reg_list
+
         conditioned_prompts = []

         for prompt, is_reg in zip(prompts, is_reg_list):
@@ -500,7 +511,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # we determine noise from the differential of the latents
             unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)

-        batch_size = latents.shape[0]
+        batch_size = len(batch.file_items)

         with self.timer('prepare_noise'):
@@ -582,6 +593,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # todo is this for sdxl? find out where this came from originally
             # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)

+        def double_up_tensor(tensor: torch.Tensor):
+            if tensor is None:
+                return None
+            return torch.cat([tensor, tensor], dim=0)
+
+        if do_double:
+            noisy_latents = double_up_tensor(noisy_latents)
+            noise = double_up_tensor(noise)
+            timesteps = double_up_tensor(timesteps)
+            # prompts are already updated above
+            imgs = double_up_tensor(imgs)
+            batch.mask_tensor = double_up_tensor(batch.mask_tensor)
+            batch.control_tensor = double_up_tensor(batch.control_tensor)
+
         # remove grads for these
         noisy_latents.requires_grad = False
         noisy_latents = noisy_latents.detach()
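With the prompt list doubled as [long..., short...] and every tensor doubled via torch.cat([t, t]), item i and item i + batch_size share the same latents, noise, and timestep but get different captions, so each image trains under both caption lengths in one step. A small sketch of that alignment (captions are made up for illustration):

import torch

batch_size = 2
latents = torch.randn(batch_size, 4, 8, 8)
long_captions = ["a red fox standing in tall grass at sunset", "a blue car parked on a wet street"]
short_captions = ["a red fox", "a blue car"]

prompts = long_captions + short_captions        # doubled caption list
latents = torch.cat([latents, latents], dim=0)  # doubled tensors

# sample i and sample i + batch_size are the same image, different caption
for i in range(batch_size):
    assert torch.equal(latents[i], latents[i + batch_size])
    print(prompts[i], "<->", prompts[i + batch_size])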
@@ -927,16 +953,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
         ### HOOK ###
         self.hook_before_train_loop()

-        if self.has_first_sample_requested:
+        if self.has_first_sample_requested and self.step_num <= 1:
             self.print("Generating first sample from first sample config")
             self.sample(0, is_first=True)

         # sample first
         if self.train_config.skip_first_sample:
             self.print("Skipping first sample due to config setting")
-        else:
+        elif self.step_num <= 1:
             self.print("Generating baseline samples before training")
-            self.sample(0)
+            self.sample(self.step_num)

         self.progress_bar = ToolkitProgressBar(
             total=self.train_config.steps,
@@ -125,6 +125,17 @@ class TrainConfig:
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)

+        # short and long captions will double your batch size. This only works when a dataset is
+        # prepared with a json caption file that has both short and long captions in it. It will
+        # double up every image and run it through with both short and long captions. The idea
+        # is that the network will learn how to generate good images with both short and long captions
+        self.short_and_long_captions = kwargs.get('short_and_long_captions', False)
+
+        # basically gradient accumulation, but we run just 1 item through the network at a time
+        # and accumulate gradients. This can be used as basic gradient accumulation, but is very helpful
+        # for training tricks that increase batch size yet need a single gradient step
+        self.single_item_batching = kwargs.get('single_item_batching', False)
+
         match_adapter_assist = kwargs.get('match_adapter_assist', False)
         self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
         self.loss_target: LossTarget = kwargs.get('loss_target', 'noise')  # noise, source, unaugmented, differential_noise
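Since TrainConfig reads these through kwargs.get, the flags would be switched on from the training config file. Assuming the YAML layout the toolkit's example configs use, it would look something like this (exact key placement is illustrative):

train:
  batch_size: 4
  # double each image: once with its long caption, once with its short one
  short_and_long_captions: true
  # run batch items through one at a time, accumulating gradients
  single_item_batching: true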
@@ -82,7 +82,7 @@ class DataLoaderBatchDTO:
         self.control_tensor: Union[torch.Tensor, None] = None
         self.mask_tensor: Union[torch.Tensor, None] = None
         self.unaugmented_tensor: Union[torch.Tensor, None] = None
         self.sigmas: Union[torch.Tensor, None] = None  # can be added elsewhere and passed along training code
         if not is_latents_cached:
             # only return a tensor if latents are not cached
             self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
@@ -160,6 +160,19 @@ class DataLoaderBatchDTO:
             add_if_not_present=add_if_not_present
         ) for x in self.file_items]

+    def get_caption_short_list(
+            self,
+            trigger=None,
+            to_replace_list=None,
+            add_if_not_present=True
+    ):
+        return [x.get_caption(
+            trigger=trigger,
+            to_replace_list=to_replace_list,
+            add_if_not_present=add_if_not_present,
+            short_caption=True
+        ) for x in self.file_items]
+
     def cleanup(self):
         del self.latents
         del self.tensor
@@ -55,6 +55,7 @@ transforms_dict = {

+caption_ext_list = ['txt', 'json', 'caption']


 def clean_caption(caption):
     # remove any newlines
     caption = caption.replace('\n', ', ')
@@ -227,6 +228,8 @@ class CaptionProcessingDTOMixin:
     def __init__(self: 'FileItemDTO', *args, **kwargs):
         if hasattr(super(), '__init__'):
             super().__init__(*args, **kwargs)
         self.raw_caption: str = None
+        self.raw_caption_short: str = None

     # todo allow for loading from sd-scripts style dict
     def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
@@ -235,15 +238,19 @@ class CaptionProcessingDTOMixin:
             pass
         elif caption_dict is not None and self.path in caption_dict and "caption" in caption_dict[self.path]:
             self.raw_caption = caption_dict[self.path]["caption"]
+            if 'caption_short' in caption_dict[self.path]:
+                self.raw_caption_short = caption_dict[self.path]["caption_short"]
         else:
             # see if prompt file exists
             path_no_ext = os.path.splitext(self.path)[0]
             prompt_ext = self.dataset_config.caption_ext
             prompt_path = f"{path_no_ext}.{prompt_ext}"
+            short_caption = None

             if os.path.exists(prompt_path):
                 with open(prompt_path, 'r', encoding='utf-8') as f:
                     prompt = f.read()
-                    short_caption = None
+                    if prompt_path.endswith('.json'):
+                        # replace any line endings with commas for \n \r \r\n
+                        prompt = prompt.replace('\r\n', ' ')
@@ -253,32 +260,36 @@ class CaptionProcessingDTOMixin:
                         prompt = json.loads(prompt)
                         if 'caption' in prompt:
+                            if 'caption_short' in prompt:
+                                short_caption = prompt['caption_short']
                             prompt = prompt['caption']
-                    # remove any newlines
-                    prompt = prompt.replace('\n', ', ')
-                    # remove new lines for all operating systems
-                    prompt = prompt.replace('\r', ', ')
-                    prompt_split = prompt.split(',')
-                    # remove empty strings
-                    prompt_split = [p.strip() for p in prompt_split if p.strip()]
-                    # join back together
-                    prompt = ', '.join(prompt_split)
+                    prompt = clean_caption(prompt)
+                    if short_caption is not None:
+                        short_caption = clean_caption(short_caption)
             else:
                 prompt = ''
                 if self.dataset_config.default_caption is not None:
                     prompt = self.dataset_config.default_caption
+            if short_caption is None:
+                short_caption = self.dataset_config.default_caption
             self.raw_caption = prompt
+            self.raw_caption_short = short_caption
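For the json caption path above, the loader expects a sidecar file with the same basename as the image, containing caption and optionally caption_short fields. A minimal hypothetical example, e.g. image1.json next to image1.png:

{
  "caption": "a red fox standing in tall grass at sunset, golden light, shallow depth of field",
  "caption_short": "a red fox in grass"
}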
     def get_caption(
             self: 'FileItemDTO',
             trigger=None,
             to_replace_list=None,
-            add_if_not_present=False
+            add_if_not_present=False,
+            short_caption=False
     ):
-        raw_caption = self.raw_caption
+        if short_caption:
+            raw_caption = self.raw_caption_short
+        else:
+            raw_caption = self.raw_caption
         if raw_caption is None:
             raw_caption = ''
         # handle dropout
-        if self.dataset_config.caption_dropout_rate > 0:
+        if self.dataset_config.caption_dropout_rate > 0 and not short_caption:
             # get a random float from 0 to 1
             rand = random.random()
             if rand < self.dataset_config.caption_dropout_rate:
@@ -296,7 +307,7 @@ class CaptionProcessingDTOMixin:
             random.shuffle(token_list)

         # handle token dropout
-        if self.dataset_config.token_dropout_rate > 0:
+        if self.dataset_config.token_dropout_rate > 0 and not short_caption:
             new_token_list = []
             for token in token_list:
                 # get a random float from 0 to 1
@@ -845,7 +856,8 @@ class LatentCachingMixin:
         self.sd.set_device_state_preset('cache_latents')

         # use tqdm to show progress
-        for i, file_item in tqdm(enumerate(self.file_list), desc=f'Caching latents{" to disk" if to_disk else ""}'):
+        i = 0
+        for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
             # set latent space version
             if self.sd.is_xl:
                 file_item.latent_space_version = 'sdxl'
@@ -891,6 +903,7 @@ class LatentCachingMixin:

             flush(garbage_collect=False)
             file_item.is_latent_cached = True
+            i += 1
             # flush every 100
             # if i % 100 == 0:
             #     flush()
@@ -89,7 +89,8 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
             torch.nn.init.zeros_(self.lora_up.weight)

         self.multiplier: Union[float, List[float]] = multiplier
-        self.org_module = org_module  # remove in applying
+        # wrap the original module so it doesn't get weights updated
+        self.org_module = [org_module]
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
@@ -98,9 +99,9 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
         self.normalize_scaler = 1.0

     def apply_to(self):
-        self.org_forward = self.org_module.forward
-        self.org_module.forward = self.forward
-        del self.org_module
+        self.org_forward = self.org_module[0].forward
+        self.org_module[0].forward = self.forward
+        # del self.org_module


 class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
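Keeping the wrapped layer in a one-element list (instead of deleting it after apply_to) is what lets merge_in reach the base weights later, and the list also hides the layer from torch.nn.Module attribute registration, so the base weights don't appear in the LoRA module's own parameters() or state_dict(). A quick demonstration of that registration behavior (class name is illustrative):

import torch

class Wrapper(torch.nn.Module):
    def __init__(self, org_module: torch.nn.Module):
        super().__init__()
        self.lora_down = torch.nn.Linear(16, 4, bias=False)
        self.lora_up = torch.nn.Linear(4, 16, bias=False)
        self.org_module = [org_module]  # list hides it from module registration

wrapped = Wrapper(torch.nn.Linear(16, 16, bias=False))
# only the LoRA weights are registered; the base layer's are not
assert set(wrapped.state_dict().keys()) == {'lora_down.weight', 'lora_up.weight'}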
@@ -170,6 +171,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.multiplier = multiplier
         self.is_sdxl = is_sdxl
         self.is_v2 = is_v2
+        self.is_merged_in = False

         if modules_dim is not None:
             print(f"create LoRA network from weights")
@@ -42,6 +42,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
         self.lora_dim = lora_dim
         self.cp = False

+        self.scalar = nn.Parameter(torch.tensor(0.0))
         orig_module_name = org_module.__class__.__name__
         if orig_module_name in CONV_MODULES:
@@ -103,7 +103,21 @@ class ToolkitModuleMixin:
     # this may get an additional positional arg or not
     def forward(self: Module, x, *args, **kwargs):
-        if not self.network_ref().is_active:
+        skip = False
+        network = self.network_ref()
+        # skip if not active
+        if not network.is_active:
+            skip = True
+
+        # skip if is merged in
+        if network.is_merged_in:
+            skip = True
+
+        # skip if multiplier is 0
+        if network._multiplier == 0:
+            skip = True
+
+        if skip:
             # network is not active, avoid doing anything
             return self.org_forward(x, *args, **kwargs)
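The new is_merged_in check is what keeps a merged network correct: once the delta is folded into the base weight, the low-rank path must be skipped or the adaptation gets applied twice. A toy check of that failure mode (names are illustrative):

import torch

base = torch.nn.Linear(16, 16, bias=False)
up, down = torch.randn(16, 4), torch.randn(4, 16)
x = torch.randn(1, 16)

y_reference = x @ (base.weight + up @ down).T  # merged-weight forward
with torch.no_grad():
    base.weight += up @ down                   # merge in

y_skip = base(x)                               # correct: skip the LoRA path
y_double = base(x) + x @ (up @ down).T         # wrong: delta applied twice

assert torch.allclose(y_reference, y_skip, atol=1e-5)
assert not torch.allclose(y_skip, y_double, atol=1e-3)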
@@ -191,6 +205,52 @@ class ToolkitModuleMixin:
             # reset the normalization scaler
             self.normalize_scaler = target_normalize_scaler

+    @torch.no_grad()
+    def merge_out(self: Module, merge_out_weight=1.0):
+        # make sure it is positive
+        merge_out_weight = abs(merge_out_weight)
+        # merging out is just merging in the negative of the weight
+        self.merge_in(merge_weight=-merge_out_weight)
+
+    @torch.no_grad()
+    def merge_in(self: Module, merge_weight=1.0):
+        # get up/down weight
+        up_weight = self.lora_up.weight.clone().float()
+        down_weight = self.lora_down.weight.clone().float()
+
+        # extract weight from org_module
+        org_sd = self.org_module[0].state_dict()
+        orig_dtype = org_sd["weight"].dtype
+        weight = org_sd["weight"].float()
+
+        multiplier = merge_weight
+        scale = self.scale
+        # handle the trainable scaler method locon uses
+        if hasattr(self, 'scalar'):
+            scale = scale * self.scalar
+
+        # merge weight
+        if len(weight.size()) == 2:
+            # linear
+            weight = weight + multiplier * (up_weight @ down_weight) * scale
+        elif down_weight.size()[2:4] == (1, 1):
+            # conv2d 1x1
+            weight = (
+                weight
+                + multiplier
+                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                * scale
+            )
+        else:
+            # conv2d 3x3
+            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+            # print(conved.size(), weight.size(), module.stride, module.padding)
+            weight = weight + multiplier * conved * scale
+
+        # set weight to org_module
+        org_sd["weight"] = weight.to(orig_dtype)
+        self.org_module[0].load_state_dict(org_sd)
+

 class ToolkitNetworkMixin:
     def __init__(
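The conv2d 3x3 branch composes the two LoRA convolutions into a single kernel: convolving the down kernel (treated as a batch of images) with the 1x1 up kernel yields the weight of the combined conv, since the two convs are applied in sequence with no nonlinearity between them. A quick numerical check of that identity (shapes chosen arbitrarily):

import torch
import torch.nn.functional as F

in_ch, rank, out_ch = 8, 4, 16
down_weight = torch.randn(rank, in_ch, 3, 3)  # conv: in_ch -> rank, 3x3
up_weight = torch.randn(out_ch, rank, 1, 1)   # conv: rank -> out_ch, 1x1

x = torch.randn(1, in_ch, 32, 32)

# sequential application of the two LoRA convs
y_seq = F.conv2d(F.conv2d(x, down_weight, padding=1), up_weight)

# composed kernel, built the same way merge_in does
conved = F.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
y_merged = F.conv2d(x, conved, padding=1)

assert torch.allclose(y_seq, y_merged, atol=1e-3)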
@@ -210,6 +270,7 @@ class ToolkitNetworkMixin:
         self._is_normalizing: bool = False
         self.is_sdxl = is_sdxl
         self.is_v2 = is_v2
+        self.is_merged_in = False
         # super().__init__(*args, **kwargs)

     def get_keymap(self: Network):
@@ -326,7 +387,6 @@ class ToolkitNetworkMixin:

         self.torch_multiplier = tensor_multiplier.clone().detach()

-
     @property
     def multiplier(self) -> Union[float, List[float]]:
         return self._multiplier
@@ -396,3 +456,15 @@ class ToolkitNetworkMixin:
     def apply_stored_normalizer(self: Network, target_normalize_scaler: float = 1.0):
         for module in self.get_all_modules():
             module.apply_stored_normalizer(target_normalize_scaler)
+
+    def merge_in(self, merge_weight=1.0):
+        self.is_merged_in = True
+        for module in self.get_all_modules():
+            module.merge_in(merge_weight)
+
+    def merge_out(self, merge_weight=1.0):
+        if not self.is_merged_in:
+            return
+        self.is_merged_in = False
+        for module in self.get_all_modules():
+            module.merge_out(merge_weight)
@@ -62,6 +62,7 @@ class BlankNetwork:
         self.multiplier = 1.0
         self.is_active = True
         self.is_normalizing = False
+        self.is_merged_in = False

     def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
         pass
@@ -267,10 +268,18 @@ class StableDiffusion:

     @torch.no_grad()
     def generate_images(self, image_configs: List[GenerateImageConfig], sampler=None):
+        merge_multiplier = 1.0
         # sample_folder = os.path.join(self.save_root, 'samples')
         if self.network is not None:
             self.network.eval()
             network = self.network
+            # check if we have the same network weight for all samples. If we do, we can merge
+            # in the network to drastically speed up inference
+            unique_network_weights = set([x.network_multiplier for x in image_configs])
+            if len(unique_network_weights) == 1:
+                can_merge_in = True
+                merge_multiplier = unique_network_weights.pop()
+                network.merge_in(merge_weight=merge_multiplier)
         else:
             network = BlankNetwork()
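Putting the pieces together, the sampling path amounts to: if every queued image uses the same network multiplier, fold the network in once, generate everything through the plain model, then fold it back out with the same weight so training resumes on unmerged weights. A compact sketch of that control flow (helper names and generate_fn are illustrative, not the toolkit's exact API):

def generate_with_optional_merge(network, image_configs, generate_fn):
    # merge only when every queued sample shares one network weight
    weights = set(cfg.network_multiplier for cfg in image_configs)
    merge_weight = weights.pop() if len(weights) == 1 else None
    if merge_weight is not None:
        network.merge_in(merge_weight=merge_weight)
    try:
        return [generate_fn(cfg) for cfg in image_configs]
    finally:
        if merge_weight is not None:
            # undo with the same weight so the base weights are restored
            network.merge_out(merge_weight)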
@@ -462,6 +471,9 @@ class StableDiffusion:
             self.network.train()
             self.network.multiplier = start_multiplier
             self.network.is_normalizing = was_network_normalizing
+
+        if network.is_merged_in:
+            network.merge_out(merge_multiplier)
         # self.tokenizer.to(original_device_dict['tokenizer'])

     def get_latent_noise(