Improved LoRM extraction and training

This commit is contained in:
Jaret Burkett
2023-10-28 08:21:59 -06:00
parent 0a79ac9604
commit 6f3e0d5af2
10 changed files with 559 additions and 196 deletions

View File

@@ -0,0 +1,102 @@
import os
from collections import OrderedDict
from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.stable_diffusion_model import StableDiffusion
import gc
import torch
from jobs.process import BaseExtensionProcess
from toolkit.train_tools import get_torch_dtype
def flush():
torch.cuda.empty_cache()
gc.collect()
class PureLoraGenerator(BaseExtensionProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
self.output_folder = self.get_conf('output_folder', required=True)
self.device = self.get_conf('device', 'cuda')
self.device_torch = torch.device(self.device)
self.model_config = ModelConfig(**self.get_conf('model', required=True))
self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
self.dtype = self.get_conf('dtype', 'float16')
self.torch_dtype = get_torch_dtype(self.dtype)
lorm_config = self.get_conf('lorm', None)
self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None
self.device_state_preset = get_train_sd_device_state_preset(
device=torch.device(self.device),
)
self.progress_bar = None
self.sd = StableDiffusion(
device=self.device,
model_config=self.model_config,
dtype=self.dtype,
)
def run(self):
super().run()
print("Loading model...")
with torch.no_grad():
self.sd.load_model()
self.sd.unet.eval()
self.sd.unet.to(self.device_torch)
if isinstance(self.sd.text_encoder, list):
for te in self.sd.text_encoder:
te.eval()
te.to(self.device_torch)
else:
self.sd.text_encoder.eval()
self.sd.to(self.device_torch)
print(f"Converting to LoRM UNet")
# replace the unet with LoRMUnet
convert_diffusers_unet_to_lorm(
self.sd.unet,
config=self.lorm_config,
)
sample_folder = self.output_folder
gen_img_config_list = []
sample_config = self.generate_config
start_seed = sample_config.seed
current_seed = start_seed
for i in range(len(sample_config.prompts)):
if sample_config.walk_seed:
current_seed = start_seed + i
filename = f"[time]_[count].{self.generate_config.ext}"
output_path = os.path.join(sample_folder, filename)
prompt = sample_config.prompts[i]
extra_args = {}
gen_img_config_list.append(GenerateImageConfig(
prompt=prompt, # it will autoparse the prompt
width=sample_config.width,
height=sample_config.height,
negative_prompt=sample_config.neg,
seed=current_seed,
guidance_scale=sample_config.guidance_scale,
guidance_rescale=sample_config.guidance_rescale,
num_inference_steps=sample_config.sample_steps,
network_multiplier=sample_config.network_multiplier,
output_path=output_path,
output_ext=sample_config.ext,
adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
**extra_args
))
# send to be generated
self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
print("Done generating images")
# cleanup
del self.sd
gc.collect()
torch.cuda.empty_cache()
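# --- Editor's sketch (not part of the commit): a minimal config this
# process could accept. Key names mirror the get_conf() calls above; the
# model path and sample values are illustrative assumptions, and
# SampleConfig/ModelConfig are assumed to map kwargs to attributes 1:1.
example_config = OrderedDict({
    'output_folder': 'output/pure_lora_samples',
    'device': 'cuda',
    'dtype': 'float16',
    'model': {'name_or_path': 'runwayml/stable-diffusion-v1-5'},  # hypothetical checkpoint
    'sample': {
        'prompts': ['a photo of a cat'],
        'width': 512, 'height': 512,
        'seed': 42, 'walk_seed': True,
        'guidance_scale': 7.0, 'sample_steps': 20,
    },
    'lorm': {'extract_mode': 'ratio', 'extract_mode_param': 0.25},
})
# PureLoraGenerator(process_id=0, job=job, config=example_config).run()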

View File

@@ -19,7 +19,24 @@ class AdvancedReferenceGeneratorExtension(Extension):
return ReferenceGenerator
# This is for generic training (LoRA, Dreambooth, FineTuning)
class PureLoraGenerator(Extension):
# uid must be unique, it is how the extension is identified
uid = "pure_lora_generator"
# name is the name of the extension for printing
name = "Pure LoRA Generator"
# This is where your process class is loaded
# keep your imports in here so they don't slow down the rest of the program
@classmethod
def get_process(cls):
# import your process class here so it is only loaded when needed and return it
from .PureLoraGenerator import PureLoraGenerator
return PureLoraGenerator
AI_TOOLKIT_EXTENSIONS = [
# you can put a list of extensions here
AdvancedReferenceGeneratorExtension,
AdvancedReferenceGeneratorExtension, PureLoraGenerator
]

View File

@@ -32,7 +32,6 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.inverted_mask_prior:
self.do_prior_prediction = True
def before_model_load(self):
pass
@@ -193,6 +192,15 @@ class SDTrainer(BaseSDTrainProcess):
self.network.is_active = was_network_active
return prior_pred
def before_unet_predict(self):
pass
def after_unet_predict(self):
pass
def end_of_training_loop(self):
pass
def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
self.timer.start('preprocess_batch')
@@ -331,7 +339,6 @@ class SDTrainer(BaseSDTrainProcess):
adapter_images_list = [adapter_images]
mask_multiplier_list = [mask_multiplier]
for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier in zip(
noisy_latents_list,
noise_list,
@@ -366,7 +373,8 @@ class SDTrainer(BaseSDTrainProcess):
# flush()
pred_kwargs = {}
if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
if has_adapter_img and (
(self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
with torch.set_grad_enabled(self.adapter is not None):
adapter = self.adapter if self.adapter else self.assistant_adapter
adapter_multiplier = get_adapter_multiplier()
@@ -406,8 +414,7 @@ class SDTrainer(BaseSDTrainProcess):
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
self.before_unet_predict()
with self.timer('predict_unet'):
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
@@ -416,6 +423,7 @@ class SDTrainer(BaseSDTrainProcess):
guidance_scale=1.0,
**pred_kwargs
)
self.after_unet_predict()
with self.timer('calculate_loss'):
noise = noise.to(self.device_torch, dtype=dtype).detach()
@@ -442,7 +450,7 @@ class SDTrainer(BaseSDTrainProcess):
loss.backward()
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
# flush()
# flush()
with self.timer('optimizer_step'):
# apply gradients
@@ -460,4 +468,6 @@ class SDTrainer(BaseSDTrainProcess):
{'loss': loss.item()}
)
self.end_of_training_loop()
return loss_dict
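# --- Editor's sketch (not part of the commit): the new no-op hooks above
# (before_unet_predict / after_unet_predict / end_of_training_loop) give
# subclasses seams around the UNet call without overriding hook_train_loop.
# The timing below is a hypothetical use, not something this commit does.
import time

class TimedSDTrainer(SDTrainer):
    def before_unet_predict(self):
        self._unet_t0 = time.time()

    def after_unet_predict(self):
        print(f"unet forward took {time.time() - self._unet_t0:.3f}s")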

View File

@@ -7,6 +7,7 @@ from typing import Union, List
import numpy as np
from diffusers import T2IAdapter
from safetensors.torch import save_file, load_file
# from lycoris.config import PRESET
from torch.utils.data import DataLoader
import torch
@@ -18,7 +19,8 @@ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatc
from toolkit.embedding import Embedding
from toolkit.ip_adapter import IPAdapter
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.lorm import convert_diffusers_unet_to_lorm
from toolkit.lorm import convert_diffusers_unet_to_lorm, count_parameters, print_lorm_extract_details, \
lorm_ignore_if_contains, lorm_parameter_threshold, LORM_TARGET_REPLACE_MODULE
from toolkit.lycoris_special import LycorisSpecialNetwork
from toolkit.network_mixins import Network
from toolkit.optimizer import get_optimizer
@@ -128,6 +130,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
is_training_adapter = self.adapter_config is not None and self.adapter_config.train
self.do_lorm = self.get_conf('do_lorm', False)
self.lorm_extract_mode = self.get_conf('lorm_extract_mode', 'ratio')
self.lorm_extract_mode_param = self.get_conf('lorm_extract_mode_param', 0.25)
# get the device state preset based on what we are training
self.train_device_state_preset = get_train_sd_device_state_preset(
@@ -300,9 +305,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
file_path = os.path.join(self.save_root, filename)
prev_multiplier = self.network.multiplier
self.network.multiplier = 1.0
if self.network_config.normalize:
# apply the normalization
self.network.apply_stored_normalizer()
# if we are doing embedding training as well, add that
embedding_dict = self.embedding.state_dict() if self.embedding else None
@@ -427,6 +429,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
print("load_weights not implemented for non-network models")
return None
def load_lorm(self):
latest_save_path = self.get_latest_save_path()
if latest_save_path is not None:
# hacky way to reload weights for now
# TODO: do this properly
state_dict = load_file(latest_save_path, device=self.device)
self.sd.unet.load_state_dict(state_dict)
meta = load_metadata_from_safetensors(latest_save_path)
# if 'training_info' is in the OrderedDict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
# def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
# self.sd.noise_scheduler.set_timesteps(1000, device=self.device_torch)
# sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype)
@@ -610,7 +627,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
batch.mask_tensor = double_up_tensor(batch.mask_tensor)
batch.control_tensor = double_up_tensor(batch.control_tensor)
# remove grads for these
noisy_latents.requires_grad = False
noisy_latents = noisy_latents.detach()
@@ -712,15 +728,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
# run base sd process run
self.sd.load_model()
if self.do_lorm:
train_modules = convert_diffusers_unet_to_lorm(self.sd.unet, 'ratio', 0.27)
for module in train_modules:
p = module.parameters()
for param in p:
param.requires_grad_(True)
params.append(param)
dtype = get_torch_dtype(self.train_config.dtype)
# model is loaded from BaseSDProcess
@@ -783,14 +790,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
if not self.is_fine_tuning:
if self.network_config is not None:
# TODO should we completely switch to LycorisSpecialNetwork?
network_kwargs = {}
is_lycoris = False
is_lorm = self.network_config.type.lower() == 'lorm'
# default to LoCon if there are any conv layers or if it is explicitly named
NetworkClass = LoRASpecialNetwork
if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris':
NetworkClass = LycorisSpecialNetwork
is_lycoris = True
if is_lorm:
network_kwargs['ignore_if_contains'] = lorm_ignore_if_contains
network_kwargs['parameter_threshold'] = lorm_parameter_threshold
network_kwargs['target_lin_modules'] = LORM_TARGET_REPLACE_MODULE
# if is_lycoris:
# preset = PRESET['full']
# NetworkClass.apply_preset(preset)
@@ -810,6 +823,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
dropout=self.network_config.dropout,
use_text_encoder_1=self.model_config.use_text_encoder_1,
use_text_encoder_2=self.model_config.use_text_encoder_2,
use_bias=is_lorm,
is_lorm=is_lorm,
network_config=self.network_config,
**network_kwargs
)
self.network.force_to(self.device_torch, dtype=dtype)
@@ -824,6 +841,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.train_config.train_unet
)
if is_lorm:
self.network.is_lorm = True
# make sure it is on the right device
self.sd.unet.to(self.sd.device, dtype=dtype)
original_unet_param_count = count_parameters(self.sd.unet)
self.network.setup_lorm()
new_unet_param_count = original_unet_param_count - self.network.calculate_lorem_parameter_reduction()
print_lorm_extract_details(
start_num_params=original_unet_param_count,
end_num_params=new_unet_param_count,
num_replaced=len(self.network.get_all_modules()),
)
self.network.prepare_grad_etc(text_encoder, unet)
flush()
@@ -846,9 +877,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.train_config.gradient_checkpointing:
self.network.enable_gradient_checkpointing()
# set the network to normalize if we are
self.network.is_normalizing = self.network_config.normalize
lora_name = self.name
# need to adapt name so they are not mixed up
if self.named_lora:
@@ -915,7 +943,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
# set the device state preset before getting params
self.sd.set_device_state(self.train_device_state_preset)
# params = self.get_params()
if len(params) == 0:
# will only return savable weights and ones with grad
@@ -1050,9 +1077,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
batch = None
# turn on normalization if we are using it and it is not on
if self.network is not None and self.network_config.normalize and not self.network.is_normalizing:
self.network.is_normalizing = True
# flush()
### HOOK ###
self.timer.start('train_loop')
@@ -1078,11 +1102,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.progress_bar.set_postfix_str(prog_bar_string)
# apply network normalizer if we are using it, not on regularization steps
if self.network is not None and self.network.is_normalizing and not is_reg_step:
with self.timer('apply_normalizer'):
self.network.apply_stored_normalizer()
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
if isinstance(batch, DataLoaderBatchDTO):
with self.timer('batch_cleanup'):
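# --- Editor's sketch (not part of the commit): the NetworkConfig that
# drives the is_lorm branch above. With type == 'lorm', config_modules
# forces linear/rank to a placeholder and LoRM settings come from the
# 'lorm' key; the kwargs shape here is an assumption and values are
# illustrative.
network_config = NetworkConfig(
    type='lorm',
    lorm={
        'extract_mode': 'ratio',
        'extract_mode_param': 0.25,
        'parameter_threshold': 1_000_000,
    },
)
assert network_config.lorm_config is not None
assert network_config.linear == 4  # placeholder rank set for lorm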

View File

@@ -40,7 +40,48 @@ class SampleConfig:
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
NetworkType = Literal['lora', 'locon']
class LormModuleSettingsConfig:
def __init__(self, **kwargs):
self.contains: str = kwargs.get('contains', '4nt$3')  # default is a junk string that will never match
self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
# min num parameters to attach to
self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
self.extract_mode_param: Union[int, float] = kwargs.get('extract_mode_param', 0.25)
class LoRMConfig:
def __init__(self, **kwargs):
self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
self.extract_mode_param: Union[int, float] = kwargs.get('extract_mode_param', 0.25)
self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
module_settings = kwargs.get('module_settings', [])
default_module_settings = {
'extract_mode': self.extract_mode,
'extract_mode_param': self.extract_mode_param,
'parameter_threshold': self.parameter_threshold,
}
module_settings = [{**default_module_settings, **module_setting, } for module_setting in module_settings]
self.module_settings: List[LormModuleSettingsConfig] = [LormModuleSettingsConfig(**module_setting) for
module_setting in module_settings]
def get_config_for_module(self, block_name):
for setting in self.module_settings:
contain_pieces = setting.contains.split('|')
if all(contain_piece in block_name for contain_piece in contain_pieces):
return setting
# try replacing the . with _
contain_pieces = setting.contains.replace('.', '_').split('|')
if all(contain_piece in block_name for contain_piece in contain_pieces):
return setting
# do default
return LormModuleSettingsConfig(**{
'extract_mode': self.extract_mode,
'extract_mode_param': self.extract_mode_param,
'parameter_threshold': self.parameter_threshold,
})
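# --- Editor's illustration (not part of the commit): how per-module
# overrides resolve. 'contains' is a '|'-separated list of substrings
# that must ALL appear in the block name (also retried with '.' replaced
# by '_'); unmatched names fall back to the top-level defaults.
lorm_cfg = LoRMConfig(
    extract_mode='ratio',
    extract_mode_param=0.25,
    module_settings=[
        {'contains': 'down_blocks|attn1', 'extract_mode_param': 0.1},
    ],
)
setting = lorm_cfg.get_config_for_module('down_blocks_0_attentions_0_attn1_to_out_0')
assert setting.extract_mode_param == 0.1   # override matched both pieces
setting = lorm_cfg.get_config_for_module('mid_block_attentions_0_proj')
assert setting.extract_mode_param == 0.25  # falls back to defaults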
NetworkType = Literal['lora', 'locon', 'lorm']
class NetworkConfig:
@@ -58,12 +99,22 @@ class NetworkConfig:
self.alpha: float = kwargs.get('alpha', 1.0)
self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
self.normalize = kwargs.get('normalize', False)
self.dropout: Union[float, None] = kwargs.get('dropout', None)
self.lorm_config: Union[LoRMConfig, None] = None
lorm = kwargs.get('lorm', None)
if lorm is not None:
self.lorm_config: LoRMConfig = LoRMConfig(**lorm)
if self.type == 'lorm':
# set linear/rank to arbitrary values so the modules still get created; extraction resizes them later
self.linear = 4
self.rank = 4
AdapterTypes = Literal['t2i', 'ip', 'ip+']
class AdapterConfig:
def __init__(self, **kwargs):
self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip
@@ -90,6 +141,7 @@ class EmbeddingConfig:
ContentOrStyleType = Literal['balanced', 'style', 'content']
LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise']
class TrainConfig:
def __init__(self, **kwargs):
self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
@@ -138,7 +190,8 @@ class TrainConfig:
match_adapter_assist = kwargs.get('match_adapter_assist', False)
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented, differential_noise
self.loss_target: LossTarget = kwargs.get('loss_target',
'noise') # noise, source, unaugmented, differential_noise
# When a mask is passed in a dataset, and this is true,
# we will predict noise without the LoRA network and use the prediction as a target for
@@ -151,7 +204,6 @@ class TrainConfig:
self.match_adapter_chance = 1.0
class ModelConfig:
def __init__(self, **kwargs):
self.name_or_path: str = kwargs.get('name_or_path', None)
@@ -216,7 +268,7 @@ class SliderConfig:
self.prompt_file: str = kwargs.get('prompt_file', None)
self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
self.use_adapter: bool = kwargs.get('use_adapter', None) # depth
self.use_adapter: bool = kwargs.get('use_adapter', None) # depth
self.adapter_img_dir = kwargs.get('adapter_img_dir', None)
self.high_ram = kwargs.get('high_ram', False)
@@ -267,9 +319,11 @@ class DatasetConfig:
self.augments: List[str] = kwargs.get('augments', [])
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
self.mask_path: str = kwargs.get('mask_path', None) # focus mask (black and white. White has higher loss than black)
self.mask_path: str = kwargs.get('mask_path',
None) # focus mask (black and white. White has higher loss than black)
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01)  # min value for the mask, 0 - 1
self.poi: Union[str, None] = kwargs.get('poi', None)  # if one is set and in json data, will be used as auto crop scale point of interest
self.poi: Union[str, None] = kwargs.get('poi',
None)  # if one is set and in json data, will be used as auto crop scale point of interest
self.num_repeats: int = kwargs.get('num_repeats', 1) # number of times to repeat dataset
# cache latents will store them in memory
self.cache_latents: bool = kwargs.get('cache_latents', False)
@@ -525,4 +579,4 @@ class GenerateImageConfig:
unconditional_prompt_embeds: Optional[PromptEmbeds] = None,
):
# this is called after prompt embeds are encoded. We can override them in the future here
pass
pass

View File

@@ -7,7 +7,9 @@ from typing import List, Optional, Dict, Type, Union
import torch
from transformers import CLIPTextModel
from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin
from .config_modules import NetworkConfig
from .lorm import count_parameters
from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin
from .paths import SD_SCRIPTS_ROOT
sys.path.append(SD_SCRIPTS_ROOT)
@@ -30,7 +32,7 @@ CONV_MODULES = [
'LoRACompatibleConv'
]
class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
@@ -46,13 +48,17 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
rank_dropout=None,
module_dropout=None,
network: 'LoRASpecialNetwork' = None,
parent=None,
use_bias: bool = False,
**kwargs
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__(network=network)
ToolkitModuleMixin.__init__(self, network=network)
torch.nn.Module.__init__(self)
self.lora_name = lora_name
self.scalar = torch.tensor(1.0)
# check if parent has bias. if not force use_bias to False
if org_module.bias is None:
use_bias = False
if org_module.__class__.__name__ in CONV_MODULES:
in_dim = org_module.in_channels
@@ -73,10 +79,10 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
stride = org_module.stride
padding = org_module.padding
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias)
else:
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
@@ -95,8 +101,6 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.is_checkpointing = False
self.is_normalizing = False
self.normalize_scaler = 1.0
def apply_to(self):
self.org_forward = self.org_module[0].forward
@@ -143,6 +147,13 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
train_unet: Optional[bool] = True,
is_sdxl=False,
is_v2=False,
use_bias: bool = False,
is_lorm: bool = False,
ignore_if_contains = None,
parameter_threshold: float = 0.0,
target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
**kwargs
) -> None:
"""
LoRA network: there are a lot of arguments, but the patterns are as follows
@@ -154,7 +165,18 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
"""
# call the parent of the parent we are replacing (LoRANetwork) init
torch.nn.Module.__init__(self)
ToolkitNetworkMixin.__init__(
self,
train_text_encoder=train_text_encoder,
train_unet=train_unet,
is_sdxl=is_sdxl,
is_v2=is_v2,
is_lorm=is_lorm,
**kwargs
)
if ignore_if_contains is None:
ignore_if_contains = []
self.ignore_if_contains = ignore_if_contains
self.lora_dim = lora_dim
self.alpha = alpha
self.conv_lora_dim = conv_lora_dim
@@ -165,13 +187,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.is_checkpointing = False
self._multiplier: float = 1.0
self.is_active: bool = False
self._is_normalizing: bool = False
self.torch_multiplier = None
# triggers the state updates
self.multiplier = multiplier
self.is_sdxl = is_sdxl
self.is_v2 = is_v2
self.is_merged_in = False
if modules_dim is not None:
print(f"create LoRA network from weights")
@@ -217,7 +237,15 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
is_conv2d = child_module.__class__.__name__ in CONV_MODULES
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
skip = False
if any([word in child_name for word in self.ignore_if_contains]):
skip = True
# see if it is over threshold
if count_parameters(child_module) < parameter_threshold:
skip = True
if (is_linear or is_conv2d) and not skip:
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
@@ -265,6 +293,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
module_dropout=module_dropout,
network=self,
parent=module,
use_bias=use_bias,
)
loras.append(lora)
return loras, skipped
@@ -295,9 +324,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
target_modules = target_lin_modules
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
target_modules += target_conv_modules
if train_unet:
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)

View File

@@ -6,6 +6,8 @@ from diffusers import UNet2DConditionModel
from torch import Tensor
from tqdm import tqdm
from toolkit.config_modules import LoRMConfig
conv = nn.Conv2d
lin = nn.Linear
_size_2_t = Union[int, Tuple[int, int]]
@@ -29,12 +31,13 @@ CONV_MODULES = [
UNET_TARGET_REPLACE_MODULE = [
"Transformer2DModel",
# "BasicTransformerBlock",
# "ResnetBlock2D",
"Downsample2D",
"Upsample2D",
]
LORM_TARGET_REPLACE_MODULE = UNET_TARGET_REPLACE_MODULE
UNET_TARGET_REPLACE_NAME = [
"conv_in",
"conv_out",
@@ -279,13 +282,38 @@ def compute_optimal_bias(original_module, linear_down, linear_up, X):
return optimal_bias
def format_with_commas(n):
return f"{n:,}"
def print_lorm_extract_details(
start_num_params: int,
end_num_params: int,
num_replaced: int,
):
start_formatted = format_with_commas(start_num_params)
end_formatted = format_with_commas(end_num_params)
num_replaced_formatted = format_with_commas(num_replaced)
width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted))
print(f"Convert UNet result:")
print(f" - converted: {num_replaced:>{width},} modules")
print(f" - start: {start_num_params:>{width},} params")
print(f" - end: {end_num_params:>{width},} params")
lorm_ignore_if_contains = [
'proj_out', 'proj_in',
]
lorm_parameter_threshold = 1000000
@torch.no_grad()
def convert_diffusers_unet_to_lorm(
unet: UNet2DConditionModel,
extract_mode: ExtractMode = "percentile",
mode_param: Union[int, float] = 0.5,
parameter_threshold: int = 500000,
# parameter_threshold: int = 1500000
config: LoRMConfig,
):
print('Converting UNet to LoRM UNet')
start_num_params = count_parameters(unet)
@@ -299,8 +327,6 @@ def convert_diffusers_unet_to_lorm(
ignore_if_contains = [
'proj_out', 'proj_in',
]
def format_with_commas(n):
return f"{n:,}"
for name, module in named_modules:
module_name = module.__class__.__name__
@@ -311,6 +337,13 @@ def convert_diffusers_unet_to_lorm(
combined_name = f"{name}.{child_name}"
# if child_module.__class__.__name__ in LINEAR_MODULES and child_module.bias is None:
# pass
lorm_config = config.get_config_for_module(combined_name)
extract_mode = lorm_config.extract_mode
extract_mode_param = lorm_config.extract_mode_param
parameter_threshold = lorm_config.parameter_threshold
if any([word in child_name for word in ignore_if_contains]):
pass
@@ -322,7 +355,7 @@ def convert_diffusers_unet_to_lorm(
down_weight, up_weight, lora_dim, diff = extract_linear(
weight=child_module.weight.clone().detach().float(),
mode=extract_mode,
mode_param=mode_param,
mode_param=extract_mode_param,
device=child_module.weight.device,
)
down_weight = down_weight.to(dtype=dtype)
@@ -362,7 +395,7 @@ def convert_diffusers_unet_to_lorm(
down_weight, up_weight, lora_dim, diff = extract_conv(
weight=child_module.weight.clone().detach().float(),
mode=extract_mode,
mode_param=mode_param,
mode_param=extract_mode_param,
device=child_module.weight.device,
)
down_weight = down_weight.to(dtype=dtype)
@@ -395,30 +428,25 @@ def convert_diffusers_unet_to_lorm(
replace_module_by_path(unet, combined_name, new_module)
converted_modules.append(new_module)
num_replaced += 1
layer_names_replaced.append(f"{combined_name} - {format_with_commas(count_parameters(child_module))}")
layer_names_replaced.append(
f"{combined_name} - {format_with_commas(count_parameters(child_module))}")
pbar.update(1)
pbar.close()
end_num_params = count_parameters(unet)
start_formatted = format_with_commas(start_num_params)
end_formatted = format_with_commas(end_num_params)
num_replaced_formatted = format_with_commas(num_replaced)
width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted))
def sorting_key(s):
# Extract the number part, remove commas, and convert to integer
return int(s.split("-")[1].strip().replace(",", ""))
sorted_layer_names_replaced = sorted(layer_names_replaced, key=sorting_key, reverse=True)
for layer_name in sorted_layer_names_replaced:
print(layer_name)
print(f"Convert UNet result:")
print(f" - converted: {num_replaced:>{width},} modules")
print(f" - start: {start_num_params:>{width},} params")
print(f" - end: {end_num_params:>{width},} params")
print_lorm_extract_details(
start_num_params=start_num_params,
end_num_params=end_num_params,
num_replaced=num_replaced,
)
return converted_modules
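# --- Editor's usage sketch (not part of the commit): converting a loaded
# diffusers UNet in place and reporting the reduction. `unet` is assumed
# to be a UNet2DConditionModel; threshold mirrors the constant above.
lorm_config = LoRMConfig(
    extract_mode='ratio',
    extract_mode_param=0.25,
    parameter_threshold=1_000_000,
)
before = count_parameters(unet)
converted = convert_diffusers_unet_to_lorm(unet, config=lorm_config)
after = count_parameters(unet)
print(f"replaced {len(converted)} modules, removed {before - after:,} params")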

View File

@@ -8,20 +8,19 @@ from lycoris.modules.glora import GLoRAModule
from torch import nn
from transformers import CLIPTextModel
from torch.nn import functional as F
from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin
from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin
# diffusers specific stuff
LINEAR_MODULES = [
'Linear',
'LoRACompatibleLinear'
# 'GroupNorm',
]
CONV_MODULES = [
'Conv2d',
'LoRACompatibleConv'
]
class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
class LoConSpecialModule(ToolkitModuleMixin, LoConModule, ExtractableModuleMixin):
def __init__(
self,
lora_name, org_module: nn.Module,
@@ -30,18 +29,20 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
dropout=0., rank_dropout=0., module_dropout=0.,
use_cp=False,
network: 'LycorisSpecialNetwork' = None,
parent=None,
use_bias=False,
**kwargs,
):
""" if alpha == 0 or None, alpha is rank (no scaling). """
# call super of super
ToolkitModuleMixin.__init__(self, network=network)
torch.nn.Module.__init__(self)
# call super of super
super().__init__(call_super_init=False, network=network)
self.lora_name = lora_name
self.lora_dim = lora_dim
self.cp = False
# check if parent has bias. if not force use_bias to False
if org_module.bias is None:
use_bias = False
self.scalar = nn.Parameter(torch.tensor(0.0))
orig_module_name = org_module.__class__.__name__
@@ -61,7 +62,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
self.cp = True
else:
self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=use_bias)
elif orig_module_name in LINEAR_MODULES:
self.isconv = False
self.down_op = F.linear
@@ -74,7 +75,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
self.lora_up = nn.Linear(lora_dim, out_dim, bias=use_bias)
else:
raise NotImplementedError
self.shape = org_module.weight.shape
@@ -159,10 +160,16 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
train_text_encoder: bool = True,
use_text_encoder_1: bool = True,
use_text_encoder_2: bool = True,
use_bias: bool = False,
is_lorm: bool = False,
**kwargs,
) -> None:
# call ToolkitNetworkMixin super
super().__init__(
ToolkitNetworkMixin.__init__(
self,
train_text_encoder=train_text_encoder,
train_unet=train_unet,
is_lorm=is_lorm,
**kwargs
)
# call the parent of the parent LycorisNetwork
@@ -217,7 +224,6 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
loras = []
# remove this
named_modules = root_module.named_modules()
modules = root_module.modules()
# add a few to the generator
for name, module in named_modules:
@@ -241,6 +247,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
use_cp,
network=self,
parent=module,
use_bias=use_bias,
**kwargs
)
elif child_module.__class__.__name__ in CONV_MODULES:
@@ -253,6 +260,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
use_cp,
network=self,
parent=module,
use_bias=use_bias,
**kwargs
)
elif conv_lora_dim > 0:
@@ -263,6 +271,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
use_cp,
network=self,
parent=module,
use_bias=use_bias,
**kwargs
)
else:
@@ -285,6 +294,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
use_cp,
parent=module,
network=self,
use_bias=use_bias,
**kwargs
)
elif module.__class__.__name__ == 'Conv2d':
@@ -297,6 +307,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
use_cp,
network=self,
parent=module,
use_bias=use_bias,
**kwargs
)
elif conv_lora_dim > 0:
@@ -307,6 +318,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
use_cp,
network=self,
parent=module,
use_bias=use_bias,
**kwargs
)
else:

View File

@@ -1,17 +1,23 @@
import json
import os
from collections import OrderedDict
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any, Literal
import torch
from torch import nn
import weakref
from tqdm import tqdm
from toolkit.config_modules import NetworkConfig
from toolkit.lorm import extract_conv, extract_linear, count_parameters
from toolkit.metadata import add_model_hash_to_meta
from toolkit.paths import KEYMAPS_ROOT
if TYPE_CHECKING:
from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
from toolkit.lora_special import LoRASpecialNetwork, LoRAModule
from toolkit.stable_diffusion_model import StableDiffusion
Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork']
Module = Union['LoConSpecialModule', 'LoRAModule']
@@ -26,6 +32,15 @@ CONV_MODULES = [
'LoRACompatibleConv'
]
ExtractMode = Union[
'existing',
'fixed',
'threshold',
'ratio',
'quantile',
'percentage'
]
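# --- Editor's sketch (not part of the commit): extract_linear/extract_conv
# live in toolkit.lorm and are not shown here; this is an assumed SVD-based
# equivalent for a Linear weight under 'ratio' mode, where mode_param is
# the kept fraction of the smaller weight dimension.
def extract_linear_sketch(weight: torch.Tensor, mode_param: float = 0.25):
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    new_dim = max(1, int(min(weight.shape) * mode_param))
    up = U[:, :new_dim] * S[:new_dim]  # (out, new_dim), singular values folded in
    down = Vh[:new_dim, :]             # (new_dim, in)
    diff = weight - up @ down          # reconstruction error
    return down, up, new_dim, diff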
def broadcast_and_multiply(tensor, multiplier):
# Determine the number of dimensions required
@@ -41,20 +56,101 @@ def broadcast_and_multiply(tensor, multiplier):
return result
def add_bias(tensor, bias):
if bias is None:
return tensor
# add batch dim
bias = bias.unsqueeze(0)
bias = torch.cat([bias] * tensor.size(0), dim=0)
# Determine the number of dimensions required
num_extra_dims = tensor.dim() - bias.dim()
# Unsqueezing the tensor to match the dimensionality
for _ in range(num_extra_dims):
bias = bias.unsqueeze(-1)
# the channel dim may not line up; permute the bias so channels align
if bias.size(1) != tensor.size(1):
if len(bias.size()) == 3:
bias = bias.permute(0, 2, 1)
elif len(bias.size()) == 4:
bias = bias.permute(0, 3, 1, 2)
# Add the broadcast bias to the output tensor
try:
result = tensor + bias
except RuntimeError as e:
print(e)
print(tensor.size())
print(bias.size())
raise e
return result
class ExtractableModuleMixin:
def extract_weight(
self: Module,
extract_mode: ExtractMode = "existing",
extract_mode_param: Union[int, float] = None,
):
device = self.lora_down.weight.device
weight_to_extract = self.org_module[0].weight
if extract_mode == "existing":
extract_mode = 'fixed'
extract_mode_param = self.lora_dim
if self.org_module[0].__class__.__name__ in CONV_MODULES:
# do conv extraction
down_weight, up_weight, new_dim, diff = extract_conv(
weight=weight_to_extract.clone().detach().float(),
mode=extract_mode,
mode_param=extract_mode_param,
device=device
)
elif self.org_module[0].__class__.__name__ in LINEAR_MODULES:
# do linear extraction
down_weight, up_weight, new_dim, diff = extract_linear(
weight=weight_to_extract.clone().detach().float(),
mode=extract_mode,
mode_param=extract_mode_param,
device=device,
)
else:
raise ValueError(f"Unknown module type: {self.org_module[0].__class__.__name__}")
self.lora_dim = new_dim
# inject weights into the param
self.lora_down.weight.data = down_weight.to(self.lora_down.weight.dtype).clone().detach()
self.lora_up.weight.data = up_weight.to(self.lora_up.weight.dtype).clone().detach()
# copy bias if we have one and are using them
if self.org_module[0].bias is not None and self.lora_up.bias is not None:
self.lora_up.bias.data = self.org_module[0].bias.data.clone().detach()
# set up alphas (alpha == new rank, so the effective scale is 1.0)
self.alpha = (self.alpha * 0) + down_weight.shape[0]
self.scale = self.alpha / self.lora_dim
# assign them
# handle the trainable scalar method LoCon uses
if hasattr(self, 'scalar'):
# scalar is a parameter; update its value to 1.0
self.scalar.data = torch.tensor(1.0).to(self.scalar.device, self.scalar.dtype)
class ToolkitModuleMixin:
def __init__(
self: Module,
*args,
network: Network,
call_super_init: bool = True,
**kwargs
):
if call_super_init:
super().__init__(*args, **kwargs)
self.network_ref: weakref.ref = weakref.ref(network)
self.is_checkpointing = False
# self.is_normalizing = False
self.normalize_scaler = 1.0
self._multiplier: Union[float, list, torch.Tensor] = None
def _call_forward(self: Module, x):
@@ -100,11 +196,40 @@ class ToolkitModuleMixin:
return lx * scale
# this may or may not receive an additional positional arg
def lorm_forward(self: Network, x, *args, **kwargs):
network: Network = self.network_ref()
if not network.is_active:
return self.org_forward(x, *args, **kwargs)
if network.lorm_train_mode == 'local':
# we are going to predict input with both and do a loss on them
inputs = x.detach()
with torch.no_grad():
# get the local prediction
target_pred = self.org_forward(inputs, *args, **kwargs).detach()
with torch.set_grad_enabled(True):
# make a prediction with the lorm
lorm_pred = self.lora_up(self.lora_down(inputs.requires_grad_(True)))
local_loss = torch.nn.functional.mse_loss(target_pred.float(), lorm_pred.float())
# backprop
local_loss.backward()
network.module_losses.append(local_loss.detach())
# return the original as we don't want our trainer to affect modules further down the line
return target_pred
else:
return self.lora_up(self.lora_down(x))
def forward(self: Module, x, *args, **kwargs):
skip = False
network = self.network_ref()
network: Network = self.network_ref()
if network.is_lorm:
# we are doing lorm
return self.lorm_forward(x, *args, **kwargs)
# skip if not active
if not network.is_active:
skip = True
@@ -130,40 +255,9 @@ class ToolkitModuleMixin:
if lora_output_batch_size != multiplier_batch_size:
num_interleaves = lora_output_batch_size // multiplier_batch_size
multiplier = multiplier.repeat_interleave(num_interleaves)
# multiplier = 1.0
if self.network_ref().is_normalizing:
with torch.no_grad():
# do this calculation without set multiplier and instead use same polarity, but with 1.0 multiplier
if isinstance(multiplier, torch.Tensor):
norm_multiplier = multiplier.clone().detach() * 10
norm_multiplier = norm_multiplier.clamp(min=-1.0, max=1.0)
else:
norm_multiplier = multiplier
# get a dim array from orig forward that had index of all dimensions except the batch and channel
# Calculate the target magnitude for the combined output
orig_max = torch.max(torch.abs(org_forwarded))
# Calculate the additional increase in magnitude that lora_output would introduce
potential_max_increase = torch.max(
torch.abs(org_forwarded + lora_output * norm_multiplier) - torch.abs(org_forwarded))
epsilon = 1e-6 # Small constant to avoid division by zero
# Calculate the scaling factor for the lora_output
# to ensure that the potential increase in magnitude doesn't change the original max
normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)
normalize_scaler = normalize_scaler.detach()
# save the scaler so it can be applied later
self.normalize_scaler = normalize_scaler.clone().detach()
lora_output = lora_output * normalize_scaler
return org_forwarded + broadcast_and_multiply(lora_output, multiplier)
x = org_forwarded + broadcast_and_multiply(lora_output, multiplier)
return x
def enable_gradient_checkpointing(self: Module):
self.is_checkpointing = True
@@ -171,40 +265,6 @@ class ToolkitModuleMixin:
def disable_gradient_checkpointing(self: Module):
self.is_checkpointing = False
@torch.no_grad()
def apply_stored_normalizer(self: Module, target_normalize_scaler: float = 1.0):
"""
Applies the previous normalization calculation to the module.
This must be called before saving or normalization will be lost.
It is probably best to call after each batch as well.
We just scale the up down weights to match this vector
:return:
"""
# get state dict
state_dict = self.state_dict()
dtype = state_dict['lora_up.weight'].dtype
device = state_dict['lora_up.weight'].device
# todo should we do this at fp32?
if isinstance(self.normalize_scaler, torch.Tensor):
scaler = self.normalize_scaler.clone().detach()
else:
scaler = torch.tensor(self.normalize_scaler).to(device, dtype=dtype)
total_module_scale = scaler / target_normalize_scaler
num_modules_layers = 2 # up and down
up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
.to(device, dtype=dtype)
# apply the scaler to the up and down weights
for key in state_dict.keys():
if key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'):
# do it in place so params are updated
state_dict[key] *= up_down_scale
# reset the normalization scaler
self.normalize_scaler = target_normalize_scaler
@torch.no_grad()
def merge_out(self: Module, merge_out_weight=1.0):
# make sure it is positive
@@ -251,6 +311,23 @@ class ToolkitModuleMixin:
org_sd["weight"] = weight.to(orig_dtype)
self.org_module[0].load_state_dict(org_sd)
def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
# LoRM (Low Rank Middle) is a method to reduce the number of parameters in a module while keeping the inputs and
# outputs the same. It is basically a LoRA but with the original module removed
# if a state dict is passed, use those weights instead of extracting
# todo load from state dict
network: Network = self.network_ref()
lorm_config = network.network_config.lorm_config.get_config_for_module(self.lora_name)
extract_mode = lorm_config.extract_mode
extract_mode_param = lorm_config.extract_mode_param
parameter_threshold = lorm_config.parameter_threshold
self.extract_weight(
extract_mode=extract_mode,
extract_mode_param=extract_mode_param
)
class ToolkitNetworkMixin:
def __init__(
@@ -260,6 +337,8 @@ class ToolkitNetworkMixin:
train_unet: Optional[bool] = True,
is_sdxl=False,
is_v2=False,
network_config: Optional[NetworkConfig] = None,
is_lorm=False,
**kwargs
):
self.train_text_encoder = train_text_encoder
@@ -267,11 +346,14 @@ class ToolkitNetworkMixin:
self.is_checkpointing = False
self._multiplier: float = 1.0
self.is_active: bool = False
self._is_normalizing: bool = False
self.is_sdxl = is_sdxl
self.is_v2 = is_v2
self.is_merged_in = False
# super().__init__(*args, **kwargs)
self.is_lorm = is_lorm
self.network_config: NetworkConfig = network_config
self.module_losses: List[torch.Tensor] = []
self.lorm_train_mode: Literal['local', None] = None
self.can_merge_in = not is_lorm
def get_keymap(self: Network):
if self.is_sdxl:
@@ -443,28 +525,41 @@ class ToolkitNetworkMixin:
self.is_checkpointing = False
self._update_checkpointing()
@property
def is_normalizing(self: Network) -> bool:
return self._is_normalizing
@is_normalizing.setter
def is_normalizing(self: Network, value: bool):
self._is_normalizing = value
# for module in self.get_all_modules():
# module.is_normalizing = self._is_normalizing
def apply_stored_normalizer(self: Network, target_normalize_scaler: float = 1.0):
for module in self.get_all_modules():
module.apply_stored_normalizer(target_normalize_scaler)
def merge_in(self, merge_weight=1.0):
self.is_merged_in = True
for module in self.get_all_modules():
module.merge_in(merge_weight)
def merge_out(self, merge_weight=1.0):
def merge_out(self: Network, merge_weight=1.0):
if not self.is_merged_in:
return
self.is_merged_in = False
for module in self.get_all_modules():
module.merge_out(merge_weight)
def extract_weight(
self: Network,
extract_mode: ExtractMode = "existing",
extract_mode_param: Union[int, float] = None,
):
if extract_mode_param is None:
raise ValueError("extract_mode_param must be set")
for module in tqdm(self.get_all_modules(), desc="Extracting weights"):
module.extract_weight(
extract_mode=extract_mode,
extract_mode_param=extract_mode_param
)
def setup_lorm(self: Network, state_dict: Optional[Dict[str, Any]] = None):
for module in tqdm(self.get_all_modules(), desc="Extracting LoRM"):
module.setup_lorm(state_dict=state_dict)
def calculate_lorem_parameter_reduction(self):
params_reduced = 0
for module in self.get_all_modules():
num_orig_module_params = count_parameters(module.org_module[0])
num_lorem_params = count_parameters(module.lora_down) + count_parameters(module.lora_up)
params_reduced += (num_orig_module_params - num_lorem_params)
return params_reduced
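# --- Editor's sketch (not part of the commit): driving the 'local' LoRM
# train mode above. Each module's lorm_forward backprops its own MSE
# against the frozen original output and appends it to module_losses, so
# the outer loop only steps the optimizer and drains losses for logging.
# `unet`, `dataloader`, and `optimizer` are assumed to exist; the batch
# keys are hypothetical.
network.lorm_train_mode = 'local'
network.is_active = True
for batch in dataloader:
    network.module_losses.clear()
    with torch.no_grad():  # per-module grads are re-enabled inside lorm_forward
        unet(batch['latents'], batch['timesteps'], batch['cond'])
    optimizer.step()
    optimizer.zero_grad()
    avg_loss = torch.stack(network.module_losses).mean().item()
    print(f"local lorm loss: {avg_loss:.6f}")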

View File

@@ -61,12 +61,8 @@ class BlankNetwork:
def __init__(self):
self.multiplier = 1.0
self.is_active = True
self.is_normalizing = False
self.is_merged_in = False
def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
pass
def __enter__(self):
self.is_active = True
@@ -180,11 +176,19 @@ class StableDiffusion:
**load_args
)
else:
pipe = pipln.from_single_file(
model_path,
device=self.device_torch,
torch_dtype=self.torch_dtype,
)
try:
pipe = pipln.from_single_file(
model_path,
device=self.device_torch,
torch_dtype=self.torch_dtype,
)
except Exception as e:
print("Error loading model from single file. Trying to load from pretrained")
pipe = pipln.from_pretrained(
model_path,
device=self.device_torch,
torch_dtype=self.torch_dtype,
)
flush()
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
@@ -277,19 +281,13 @@ class StableDiffusion:
# check if we have the same network weight for all samples. If we do, we can merge in
# the network to drastically speed up inference
unique_network_weights = set([x.network_multiplier for x in image_configs])
if len(unique_network_weights) == 1:
if len(unique_network_weights) == 1 and self.network.can_merge_in:
can_merge_in = True
merge_multiplier = unique_network_weights.pop()
network.merge_in(merge_weight=merge_multiplier)
else:
network = BlankNetwork()
was_network_normalizing = network.is_normalizing
# apply the normalizer if it is normalizing before inference and disable it
if network.is_normalizing:
network.apply_stored_normalizer()
network.is_normalizing = False
self.save_device_state()
self.set_device_state_preset('generate')
@@ -471,7 +469,6 @@ class StableDiffusion:
if self.network is not None:
self.network.train()
self.network.multiplier = start_multiplier
self.network.is_normalizing = was_network_normalizing
if network.is_merged_in:
network.merge_out(merge_multiplier)