Allow short and long caption combinations from the new captioning system. Merge the network into the model before inference and re-extract it when done; this doubles inference speed on LoCon models. Allow splitting a batch into individual components and running them through alone — basically gradient accumulation with a batch size of one.

This commit is contained in:
Jaret Burkett
2023-10-24 16:02:07 -06:00
parent 73c8b50975
commit 002279cec3
9 changed files with 315 additions and 115 deletions

View File

@@ -172,7 +172,9 @@ class SDTrainer(BaseSDTrainProcess):
with torch.no_grad():
dtype = get_torch_dtype(self.train_config.dtype)
# dont use network on this
self.network.multiplier = 0.0
# self.network.multiplier = 0.0
was_network_active = self.network.is_active
self.network.is_active = False
self.sd.unet.eval()
prior_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
@@ -187,7 +189,8 @@ class SDTrainer(BaseSDTrainProcess):
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
del pred_kwargs['down_block_additional_residuals']
# restore network
self.network.multiplier = network_weight_list
# self.network.multiplier = network_weight_list
self.network.is_active = was_network_active
return prior_pred
def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
@@ -197,6 +200,8 @@ class SDTrainer(BaseSDTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
network_weight_list = batch.get_network_weight_list()
if self.train_config.single_item_batching:
network_weight_list = network_weight_list + network_weight_list
has_adapter_img = batch.control_tensor is not None
@@ -234,7 +239,7 @@ class SDTrainer(BaseSDTrainProcess):
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
mask_multiplier = 1.0
mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
if batch.mask_tensor is not None:
with self.timer('get_mask_multiplier'):
# upsampling not supported for bfloat16
@@ -297,107 +302,152 @@ class SDTrainer(BaseSDTrainProcess):
self.optimizer.zero_grad(set_to_none=True)
# activate network if it exits
with network:
with self.timer('encode_prompt'):
if grad_on_text_encoder:
with torch.set_grad_enabled(True):
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
self.device_torch,
dtype=dtype)
else:
with torch.set_grad_enabled(False):
# make sure it is in eval mode
if isinstance(self.sd.text_encoder, list):
for te in self.sd.text_encoder:
te.eval()
else:
self.sd.text_encoder.eval()
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
self.device_torch,
dtype=dtype)
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
# make the batch splits
if self.train_config.single_item_batching:
batch_size = noisy_latents.shape[0]
# chunk/split everything
noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
noise_list = torch.chunk(noise, batch_size, dim=0)
timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
conditioned_prompts_list = [[prompt] for prompt in conditioned_prompts]
if imgs is not None:
imgs_list = torch.chunk(imgs, batch_size, dim=0)
else:
imgs_list = [None for _ in range(batch_size)]
if adapter_images is not None:
adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0)
else:
adapter_images_list = [None for _ in range(batch_size)]
mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
# flush()
pred_kwargs = {}
if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
with torch.set_grad_enabled(self.adapter is not None):
adapter = self.adapter if self.adapter else self.assistant_adapter
adapter_multiplier = get_adapter_multiplier()
else:
# put it all in an array
noisy_latents_list = [noisy_latents]
noise_list = [noise]
timesteps_list = [timesteps]
conditioned_prompts_list = [conditioned_prompts]
imgs_list = [imgs]
adapter_images_list = [adapter_images]
mask_multiplier_list = [mask_multiplier]
for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier in zip(
noisy_latents_list,
noise_list,
timesteps_list,
conditioned_prompts_list,
imgs_list,
adapter_images_list,
mask_multiplier_list
):
with network:
with self.timer('encode_prompt'):
if grad_on_text_encoder:
with torch.set_grad_enabled(True):
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
self.device_torch,
dtype=dtype)
else:
with torch.set_grad_enabled(False):
# make sure it is in eval mode
if isinstance(self.sd.text_encoder, list):
for te in self.sd.text_encoder:
te.eval()
else:
self.sd.text_encoder.eval()
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
self.device_torch,
dtype=dtype)
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
# flush()
pred_kwargs = {}
if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
with torch.set_grad_enabled(self.adapter is not None):
adapter = self.adapter if self.adapter else self.assistant_adapter
adapter_multiplier = get_adapter_multiplier()
with self.timer('encode_adapter'):
down_block_additional_residuals = adapter(adapter_images)
if self.assistant_adapter:
# not training. detach
down_block_additional_residuals = [
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
down_block_additional_residuals
]
else:
down_block_additional_residuals = [
sample.to(dtype=dtype) * adapter_multiplier for sample in
down_block_additional_residuals
]
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
prior_pred = None
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
with self.timer('prior predict'):
prior_pred = self.get_prior_prediction(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
noise=noise,
batch=batch,
)
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter'):
down_block_additional_residuals = adapter(adapter_images)
if self.assistant_adapter:
# not training. detach
down_block_additional_residuals = [
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
down_block_additional_residuals
]
else:
down_block_additional_residuals = [
sample.to(dtype=dtype) * adapter_multiplier for sample in
down_block_additional_residuals
]
with torch.no_grad():
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
prior_pred = None
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
with self.timer('prior predict'):
prior_pred = self.get_prior_prediction(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
noise=noise,
batch=batch,
with self.timer('predict_unet'):
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs
)
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter'):
with torch.no_grad():
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
with self.timer('calculate_loss'):
noise = noise.to(self.device_torch, dtype=dtype).detach()
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
mask_multiplier=mask_multiplier,
prior_pred=prior_pred,
)
# check if nan
if torch.isnan(loss):
raise ValueError("loss is nan")
with self.timer('predict_unet'):
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs
)
with self.timer('backward'):
# IMPORTANT if gradient checkpointing do not leave with network when doing backward
# it will destroy the gradients. This is because the network is a context manager
# and will change the multipliers back to 0.0 when exiting. They will be
# 0.0 for the backward pass and the gradients will be 0.0
# I spent weeks on fighting this. DON'T DO IT
# with fsdp_overlap_step_with_backward():
loss.backward()
with self.timer('calculate_loss'):
noise = noise.to(self.device_torch, dtype=dtype).detach()
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
mask_multiplier=mask_multiplier,
prior_pred=prior_pred,
)
# check if nan
if torch.isnan(loss):
raise ValueError("loss is nan")
with self.timer('backward'):
# IMPORTANT if gradient checkpointing do not leave with network when doing backward
# it will destroy the gradients. This is because the network is a context manager
# and will change the multipliers back to 0.0 when exiting. They will be
# 0.0 for the backward pass and the gradients will be 0.0
# I spent weeks on fighting this. DON'T DO IT
loss.backward()
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
# flush()
with self.timer('optimizer_step'):
# apply gradients
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
with self.timer('scheduler_step'):
self.lr_scheduler.step()

View File

@@ -460,6 +460,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
prompts = batch.get_caption_list()
is_reg_list = batch.get_is_reg_list()
is_any_reg = any([is_reg for is_reg in is_reg_list])
do_double = self.train_config.short_and_long_captions and not is_any_reg
if self.train_config.short_and_long_captions and do_double:
# dont do this with regs. No point
# double batch and add short captions to the end
prompts = prompts + batch.get_caption_short_list()
is_reg_list = is_reg_list + is_reg_list
conditioned_prompts = []
for prompt, is_reg in zip(prompts, is_reg_list):
@@ -500,7 +511,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# we determine noise from the differential of the latents
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
batch_size = latents.shape[0]
batch_size = len(batch.file_items)
with self.timer('prepare_noise'):
@@ -582,6 +593,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
# todo is this for sdxl? find out where this came from originally
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
def double_up_tensor(tensor: torch.Tensor):
    """Stack *tensor* with a copy of itself along the batch dimension (dim 0).

    Passes None through unchanged so optional batch tensors (masks, control
    tensors) can be doubled without a presence check at every call site.
    """
    return None if tensor is None else torch.cat((tensor, tensor), dim=0)
if do_double:
noisy_latents = double_up_tensor(noisy_latents)
noise = double_up_tensor(noise)
timesteps = double_up_tensor(timesteps)
# prompts are already updated above
imgs = double_up_tensor(imgs)
batch.mask_tensor = double_up_tensor(batch.mask_tensor)
batch.control_tensor = double_up_tensor(batch.control_tensor)
# remove grads for these
noisy_latents.requires_grad = False
noisy_latents = noisy_latents.detach()
@@ -927,16 +953,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
### HOOK ###
self.hook_before_train_loop()
if self.has_first_sample_requested:
if self.has_first_sample_requested and self.step_num <= 1:
self.print("Generating first sample from first sample config")
self.sample(0, is_first=True)
# sample first
if self.train_config.skip_first_sample:
self.print("Skipping first sample due to config setting")
else:
elif self.step_num <= 1:
self.print("Generating baseline samples before training")
self.sample(0)
self.sample(self.step_num)
self.progress_bar = ToolkitProgressBar(
total=self.train_config.steps,

View File

@@ -125,6 +125,17 @@ class TrainConfig:
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
# short long captions will double your batch size. This only works when a dataset is
# prepared with a json caption file that has both short and long captions in it. It will
# Double up every image and run it through with both short and long captions. The idea
# is that the network will learn how to generate good images with both short and long captions
self.short_and_long_captions = kwargs.get('short_and_long_captions', False)
# basically gradient accumulation but we run just 1 item through the network
# and accumulate gradients. This can be used as basic gradient accumulation but is very helpful
# for training tricks that increase batch size but need a single gradient step
self.single_item_batching = kwargs.get('single_item_batching', False)
match_adapter_assist = kwargs.get('match_adapter_assist', False)
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented, differential_noise

View File

@@ -82,7 +82,7 @@ class DataLoaderBatchDTO:
self.control_tensor: Union[torch.Tensor, None] = None
self.mask_tensor: Union[torch.Tensor, None] = None
self.unaugmented_tensor: Union[torch.Tensor, None] = None
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
if not is_latents_cached:
# only return a tensor if latents are not cached
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
@@ -160,6 +160,19 @@ class DataLoaderBatchDTO:
add_if_not_present=add_if_not_present
) for x in self.file_items]
def get_caption_short_list(
        self,
        trigger=None,
        to_replace_list=None,
        add_if_not_present=True
):
    """Return the short caption for every file item in the batch.

    Mirrors get_caption_list but requests the short-caption variant
    (raw_caption_short) from each file item, used by the
    short_and_long_captions batch-doubling path.

    :param trigger: trigger word handling, passed through to get_caption
    :param to_replace_list: replacement list, passed through to get_caption
    :param add_if_not_present: whether to inject the trigger when missing
    """
    return [x.get_caption(
        trigger=trigger,
        to_replace_list=to_replace_list,
        add_if_not_present=add_if_not_present,
        # BUG FIX: was short_caption=False, which silently returned the
        # long caption and defeated the short/long caption feature
        short_caption=True
    ) for x in self.file_items]
def cleanup(self):
del self.latents
del self.tensor

View File

@@ -55,6 +55,7 @@ transforms_dict = {
caption_ext_list = ['txt', 'json', 'caption']
def clean_caption(caption):
# remove any newlines
caption = caption.replace('\n', ', ')
@@ -227,6 +228,8 @@ class CaptionProcessingDTOMixin:
def __init__(self: 'FileItemDTO', *args, **kwargs):
if hasattr(super(), '__init__'):
super().__init__(*args, **kwargs)
self.raw_caption: str = None
self.raw_caption_short: str = None
# todo allow for loading from sd-scripts style dict
def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
@@ -235,15 +238,19 @@ class CaptionProcessingDTOMixin:
pass
elif caption_dict is not None and self.path in caption_dict and "caption" in caption_dict[self.path]:
self.raw_caption = caption_dict[self.path]["caption"]
if 'caption_short' in caption_dict[self.path]:
self.raw_caption_short = caption_dict[self.path]["caption_short"]
else:
# see if prompt file exists
path_no_ext = os.path.splitext(self.path)[0]
prompt_ext = self.dataset_config.caption_ext
prompt_path = f"{path_no_ext}.{prompt_ext}"
short_caption = None
if os.path.exists(prompt_path):
with open(prompt_path, 'r', encoding='utf-8') as f:
prompt = f.read()
short_caption = None
if prompt_path.endswith('.json'):
# replace any line endings with commas for \n \r \r\n
prompt = prompt.replace('\r\n', ' ')
@@ -253,32 +260,36 @@ class CaptionProcessingDTOMixin:
prompt = json.loads(prompt)
if 'caption' in prompt:
prompt = prompt['caption']
# remove any newlines
prompt = prompt.replace('\n', ', ')
# remove new lines for all operating systems
prompt = prompt.replace('\r', ', ')
prompt_split = prompt.split(',')
# remove empty strings
prompt_split = [p.strip() for p in prompt_split if p.strip()]
# join back together
prompt = ', '.join(prompt_split)
if 'caption_short' in prompt:
short_caption = prompt['caption_short']
prompt = clean_caption(prompt)
if short_caption is not None:
short_caption = clean_caption(short_caption)
else:
prompt = ''
if self.dataset_config.default_caption is not None:
prompt = self.dataset_config.default_caption
if short_caption is None:
short_caption = self.dataset_config.default_caption
self.raw_caption = prompt
self.raw_caption_short = short_caption
def get_caption(
self: 'FileItemDTO',
trigger=None,
to_replace_list=None,
add_if_not_present=False
add_if_not_present=False,
short_caption=False
):
raw_caption = self.raw_caption
if short_caption:
raw_caption = self.raw_caption_short
else:
raw_caption = self.raw_caption
if raw_caption is None:
raw_caption = ''
# handle dropout
if self.dataset_config.caption_dropout_rate > 0:
if self.dataset_config.caption_dropout_rate > 0 and not short_caption:
# get a random float from 0 to 1
rand = random.random()
if rand < self.dataset_config.caption_dropout_rate:
@@ -296,7 +307,7 @@ class CaptionProcessingDTOMixin:
random.shuffle(token_list)
# handle token dropout
if self.dataset_config.token_dropout_rate > 0:
if self.dataset_config.token_dropout_rate > 0 and not short_caption:
new_token_list = []
for token in token_list:
# get a random float from 0 to 1
@@ -845,7 +856,8 @@ class LatentCachingMixin:
self.sd.set_device_state_preset('cache_latents')
# use tqdm to show progress
for i, file_item in tqdm(enumerate(self.file_list), desc=f'Caching latents{" to disk" if to_disk else ""}'):
i = 0
for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
# set latent space version
if self.sd.is_xl:
file_item.latent_space_version = 'sdxl'
@@ -891,6 +903,7 @@ class LatentCachingMixin:
flush(garbage_collect=False)
file_item.is_latent_cached = True
i += 1
# flush every 100
# if i % 100 == 0:
# flush()

View File

@@ -89,7 +89,8 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
torch.nn.init.zeros_(self.lora_up.weight)
self.multiplier: Union[float, List[float]] = multiplier
self.org_module = org_module # remove in applying
# wrap the original module so it doesn't get weights updated
self.org_module = [org_module]
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
@@ -98,9 +99,9 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
self.normalize_scaler = 1.0
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
# del self.org_module
class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
@@ -170,6 +171,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.multiplier = multiplier
self.is_sdxl = is_sdxl
self.is_v2 = is_v2
self.is_merged_in = False
if modules_dim is not None:
print(f"create LoRA network from weights")

View File

@@ -42,6 +42,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
self.lora_dim = lora_dim
self.cp = False
self.scalar = nn.Parameter(torch.tensor(0.0))
orig_module_name = org_module.__class__.__name__
if orig_module_name in CONV_MODULES:

View File

@@ -103,7 +103,21 @@ class ToolkitModuleMixin:
# this may get an additional positional arg or not
def forward(self: Module, x, *args, **kwargs):
if not self.network_ref().is_active:
skip = False
network = self.network_ref()
# skip if not active
if not network.is_active:
skip = True
# skip if is merged in
if network.is_merged_in:
skip = True
# skip if multiplier is 0
if network._multiplier == 0:
skip = True
if skip:
# network is not active, avoid doing anything
return self.org_forward(x, *args, **kwargs)
@@ -191,6 +205,52 @@ class ToolkitModuleMixin:
# reset the normalization scaler
self.normalize_scaler = target_normalize_scaler
@torch.no_grad()
def merge_out(self: Module, merge_out_weight=1.0):
    """Reverse a previous merge_in by subtracting the LoRA delta.

    The magnitude is forced positive and then negated, so merging out is
    simply merge_in with the opposite sign.
    """
    self.merge_in(merge_weight=-abs(merge_out_weight))
@torch.no_grad()
def merge_in(self: Module, merge_weight=1.0):
    """Fold this module's LoRA delta directly into the wrapped base weight.

    Computes W' = W + merge_weight * (up @ down) * scale and writes it back
    into org_module's state dict, so inference runs without the extra LoRA
    forward pass. Math is done in float32 and cast back to the original
    dtype at the end. Call merge_out with the same weight to undo.
    """
    # get up/down weight (float32 copies so the merge math is stable)
    up_weight = self.lora_up.weight.clone().float()
    down_weight = self.lora_down.weight.clone().float()
    # extract weight from org_module
    org_sd = self.org_module[0].state_dict()
    orig_dtype = org_sd["weight"].dtype
    weight = org_sd["weight"].float()
    multiplier = merge_weight
    scale = self.scale
    # handle trainable scaler method locon does
    # NOTE(review): self.scalar is an nn.Parameter; scale becomes a tensor here —
    # appears intentional, confirm it broadcasts for all three branches below
    if hasattr(self, 'scalar'):
        scale = scale * self.scalar
    # merge weight — branch on the wrapped layer's weight shape
    if len(weight.size()) == 2:
        # linear
        weight = weight + multiplier * (up_weight @ down_weight) * scale
    elif down_weight.size()[2:4] == (1, 1):
        # conv2d 1x1: a 1x1 conv is a matmul over channels, so squeeze the
        # spatial dims, multiply, then restore them
        weight = (
            weight
            + multiplier
            * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
            * scale
        )
    else:
        # conv2d 3x3: compose the two conv kernels via conv2d on the
        # permuted down weight
        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
        # print(conved.size(), weight.size(), module.stride, module.padding)
        weight = weight + multiplier * conved * scale
    # set weight to org_module (cast back to the layer's original dtype)
    org_sd["weight"] = weight.to(orig_dtype)
    self.org_module[0].load_state_dict(org_sd)
class ToolkitNetworkMixin:
def __init__(
@@ -210,6 +270,7 @@ class ToolkitNetworkMixin:
self._is_normalizing: bool = False
self.is_sdxl = is_sdxl
self.is_v2 = is_v2
self.is_merged_in = False
# super().__init__(*args, **kwargs)
def get_keymap(self: Network):
@@ -326,7 +387,6 @@ class ToolkitNetworkMixin:
self.torch_multiplier = tensor_multiplier.clone().detach()
@property
def multiplier(self) -> Union[float, List[float]]:
return self._multiplier
@@ -396,3 +456,15 @@ class ToolkitNetworkMixin:
def apply_stored_normalizer(self: Network, target_normalize_scaler: float = 1.0):
for module in self.get_all_modules():
module.apply_stored_normalizer(target_normalize_scaler)
def merge_in(self, merge_weight=1.0):
    """Merge every module's LoRA delta into its wrapped base weights.

    Sets is_merged_in so the module forwards skip the LoRA path while merged.
    Guarded so repeated calls cannot stack the delta into the base weights
    twice (mirrors the is_merged_in guard in merge_out).
    """
    if self.is_merged_in:
        return
    self.is_merged_in = True
    for module in self.get_all_modules():
        module.merge_in(merge_weight)
def merge_out(self, merge_weight=1.0):
    """Back the LoRA delta out of the base weights on every module.

    No-op when nothing is currently merged in, so a stray call cannot
    corrupt the base weights.
    """
    if not self.is_merged_in:
        return
    self.is_merged_in = False
    for lora_module in self.get_all_modules():
        lora_module.merge_out(merge_weight)

View File

@@ -62,6 +62,7 @@ class BlankNetwork:
self.multiplier = 1.0
self.is_active = True
self.is_normalizing = False
self.is_merged_in = False
def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
pass
@@ -267,10 +268,18 @@ class StableDiffusion:
@torch.no_grad()
def generate_images(self, image_configs: List[GenerateImageConfig], sampler=None):
merge_multiplier = 1.0
# sample_folder = os.path.join(self.save_root, 'samples')
if self.network is not None:
self.network.eval()
network = self.network
# check if we have the same network weight for all samples. If we do, we can merge
# the network into the model to drastically speed up inference
unique_network_weights = set([x.network_multiplier for x in image_configs])
if len(unique_network_weights) == 1:
can_merge_in = True
merge_multiplier = unique_network_weights.pop()
network.merge_in(merge_weight=merge_multiplier)
else:
network = BlankNetwork()
@@ -462,6 +471,9 @@ class StableDiffusion:
self.network.train()
self.network.multiplier = start_multiplier
self.network.is_normalizing = was_network_normalizing
if network.is_merged_in:
network.merge_out(merge_multiplier)
# self.tokenizer.to(original_device_dict['tokenizer'])
def get_latent_noise(