diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py index 04ff0ff3..a2d5df69 100644 --- a/extensions_built_in/diffusion_models/__init__.py +++ b/extensions_built_in/diffusion_models/__init__.py @@ -3,7 +3,7 @@ from .hidream import HidreamModel, HidreamE1Model from .f_light import FLiteModel from .omnigen2 import OmniGen2Model from .flux_kontext import FluxKontextModel -from .wan22 import Wan225bModel +from .wan22 import Wan225bModel, Wan2214bModel from .qwen_image import QwenImageModel AI_TOOLKIT_MODELS = [ @@ -15,5 +15,6 @@ AI_TOOLKIT_MODELS = [ OmniGen2Model, FluxKontextModel, Wan225bModel, + Wan2214bModel, QwenImageModel, ] diff --git a/extensions_built_in/diffusion_models/wan22/__init__.py b/extensions_built_in/diffusion_models/wan22/__init__.py index 993acbfb..5c88152b 100644 --- a/extensions_built_in/diffusion_models/wan22/__init__.py +++ b/extensions_built_in/diffusion_models/wan22/__init__.py @@ -1 +1,2 @@ -from .wan22_5b_model import Wan225bModel \ No newline at end of file +from .wan22_5b_model import Wan225bModel +from .wan22_14b_model import Wan2214bModel \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py new file mode 100644 index 00000000..3266d320 --- /dev/null +++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py @@ -0,0 +1,470 @@ +from functools import partial +import os +from typing import Any, Dict, Optional, Union +from typing_extensions import Self +import torch +import yaml +from toolkit.accelerator import unwrap_model +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from PIL import Image +from diffusers import UniPCMultistepScheduler +import torch +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.util.quantize import quantize_model +from .wan22_pipeline import Wan22Pipeline +from diffusers import WanTransformer3DModel + +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from torchvision.transforms import functional as TF + +from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline +from .wan22_5b_model import ( + scheduler_config, + time_text_monkeypatch, + Wan225bModel, +) +from safetensors.torch import load_file, save_file + + +boundary_ratio_t2v = 0.875 +boundary_ratio_i2v = 0.9 + +scheduler_configUniPC = { + "_class_name": "UniPCMultistepScheduler", + "_diffusers_version": "0.35.0.dev0", + "beta_end": 0.02, + "beta_schedule": "linear", + "beta_start": 0.0001, + "disable_corrector": [], + "dynamic_thresholding_ratio": 0.995, + "final_sigmas_type": "zero", + "flow_shift": 3.0, + "lower_order_final": True, + "num_train_timesteps": 1000, + "predict_x0": True, + "prediction_type": "flow_prediction", + "rescale_betas_zero_snr": False, + "sample_max_value": 1.0, + "solver_order": 2, + "solver_p": None, + "solver_type": "bh2", + "steps_offset": 0, + "thresholding": False, + "time_shift_type": "exponential", + "timestep_spacing": "linspace", + "trained_betas": None, + "use_beta_sigmas": False, + "use_dynamic_shifting": False, + "use_exponential_sigmas": False, + "use_flow_sigmas": True, + "use_karras_sigmas": False, +} + + +class DualWanTransformer3DModel(torch.nn.Module): + def __init__( + self, + transformer_1: WanTransformer3DModel, + transformer_2: WanTransformer3DModel, + torch_dtype: 
Optional[Union[str, torch.dtype]] = None,
+        device: Optional[Union[str, torch.device]] = None,
+        boundary_ratio: float = boundary_ratio_t2v,
+        low_vram: bool = False,
+    ) -> None:
+        super().__init__()
+        self.transformer_1: WanTransformer3DModel = transformer_1
+        self.transformer_2: WanTransformer3DModel = transformer_2
+        self.torch_dtype: torch.dtype = torch_dtype
+        self.device_torch: torch.device = device
+        self.boundary_ratio: float = boundary_ratio
+        self.boundary: float = self.boundary_ratio * 1000
+        self.low_vram: bool = low_vram
+        self._active_transformer_name = "transformer_1"  # default to transformer_1
+
+    @property
+    def device(self) -> torch.device:
+        return self.device_torch
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.torch_dtype
+
+    @property
+    def config(self):
+        return self.transformer_1.config
+
+    @property
+    def transformer(self) -> WanTransformer3DModel:
+        return getattr(self, self._active_transformer_name)
+
+    def enable_gradient_checkpointing(self):
+        """
+        Enable gradient checkpointing for both transformers.
+        """
+        self.transformer_1.enable_gradient_checkpointing()
+        self.transformer_2.enable_gradient_checkpointing()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_hidden_states_image: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+        # determine whether to run the high-noise or low-noise expert by taking the mean of the timesteps.
+        # timesteps are in the range of 0 to 1000, so we can use a simple threshold
+        with torch.no_grad():
+            if timestep.float().mean().item() >= self.boundary:
+                t_name = "transformer_1"
+            else:
+                t_name = "transformer_2"
+
+            # if the active transformer is changing and low_vram is enabled,
+            # swap which transformer is resident in VRAM
+            # todo swap the loras as well
+            if t_name != self._active_transformer_name:
+                if self.low_vram:
+                    getattr(self, self._active_transformer_name).to("cpu")
+                    getattr(self, t_name).to(self.device_torch)
+                    torch.cuda.empty_cache()
+                self._active_transformer_name = t_name
+            if self.transformer.device != hidden_states.device:
+                raise ValueError(
+                    f"Transformer device {self.transformer.device} does not match hidden states device {hidden_states.device}"
+                )
+
+        return self.transformer(
+            hidden_states=hidden_states,
+            timestep=timestep,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_hidden_states_image=encoder_hidden_states_image,
+            return_dict=return_dict,
+            attention_kwargs=attention_kwargs,
+        )
+
+    def to(self, *args, **kwargs) -> Self:
+        # intentionally a no-op; device placement is handled separately for each transformer
+        return self
+
+
+class Wan2214bModel(Wan225bModel):
+    arch = "wan22_14b"
+    _wan_generation_scheduler_config = scheduler_configUniPC
+    _wan_expand_timesteps = True
+    _wan_vae_path = "ai-toolkit/wan2.1-vae"
+
+    def __init__(
+        self,
+        device,
+        model_config: ModelConfig,
+        dtype="bf16",
+        custom_pipeline=None,
+        noise_scheduler=None,
+        **kwargs,
+    ):
+        super().__init__(
+            device=device,
+            model_config=model_config,
+            dtype=dtype,
+            custom_pipeline=custom_pipeline,
+            noise_scheduler=noise_scheduler,
+            **kwargs,
+        )
+        # target the wrapper module so the LoRA can cover both transformers
+        self.target_lora_modules = ["DualWanTransformer3DModel"]
+        self._wan_cache = None
+
+    @property
+    def max_step_saves_to_keep_multiplier(self):
+        # the cleanup mechanism checks this to see how many saves to keep
+        # if we are training a LoRA, return 2 so that both the high-noise and low-noise LoRA files are kept for each save
+        if self.network is not None:
+            return 2
+        return 1
+
+    def load_model(self):
+        # load the model from the grandparent class (Wan21) rather than the
+        # immediate parent (Wan225bModel)
+        # super().load_model()
+        super(Wan225bModel, self).load_model()
+
+        # split the dual model across the pipeline's two transformer slots
+        self.pipeline.transformer = self.model.transformer_1
+        self.pipeline.transformer_2 = self.model.transformer_2
+
+        # patch the condition embedder
+        self.model.transformer_1.condition_embedder.forward = partial(
+            time_text_monkeypatch, self.model.transformer_1.condition_embedder
+        )
+        self.model.transformer_2.condition_embedder.forward = partial(
+            time_text_monkeypatch, self.model.transformer_2.condition_embedder
+        )
+
+    def get_bucket_divisibility(self):
+        # 16x compression and 2x2 patch size
+        return 32
+
+    def load_wan_transformer(self, transformer_path, subfolder=None):
+        if self.model_config.split_model_over_gpus:
+            raise ValueError(
+                "Splitting model over gpus is not supported for Wan2.2 models"
+            )
+
+        if (
+            self.model_config.assistant_lora_path is not None
+            or self.model_config.inference_lora_path is not None
+        ):
+            raise ValueError(
+                "Assistant LoRA is not supported for Wan2.2 models currently"
+            )
+
+        if self.model_config.lora_path is not None:
+            raise ValueError(
+                "Loading LoRA is not supported for Wan2.2 models currently"
+            )
+
+        # transformer_path can be a local directory ending in /transformer or a Hugging Face repo id.
+
+        transformer_path_1 = transformer_path
+        subfolder_1 = subfolder
+
+        transformer_path_2 = transformer_path
+        subfolder_2 = subfolder
+
+        if subfolder_2 is None:
+            # local path: point at the sibling transformer_2 folder
+            transformer_path_2 = os.path.join(
+                os.path.dirname(transformer_path_1), "transformer_2"
+            )
+        else:
+            # Hugging Face repo: use the transformer_2 subfolder
+            subfolder_2 = "transformer_2"
+
+        self.print_and_status_update("Loading transformer 1")
+        dtype = self.torch_dtype
+        transformer_1 = WanTransformer3DModel.from_pretrained(
+            transformer_path_1,
+            subfolder=subfolder_1,
+            torch_dtype=dtype,
+        ).to(dtype=dtype)
+
+        flush()
+
+        if not self.model_config.low_vram:
+            # move to the quantize device so quantization happens on GPU
+            transformer_1.to(self.quantize_device, dtype=dtype)
+            flush()
+
+        if self.model_config.quantize:
+            # todo handle two ARAs
+            self.print_and_status_update("Quantizing Transformer 1")
+            quantize_model(self, transformer_1)
+            flush()
+
+        if self.model_config.low_vram:
+            self.print_and_status_update("Moving transformer 1 to CPU")
+            transformer_1.to("cpu")
+
+        self.print_and_status_update("Loading transformer 2")
+        dtype = self.torch_dtype
+        transformer_2 = WanTransformer3DModel.from_pretrained(
+            transformer_path_2,
+            subfolder=subfolder_2,
+            torch_dtype=dtype,
+        ).to(dtype=dtype)
+
+        flush()
+
+        if not self.model_config.low_vram:
+            # move to the quantize device so quantization happens on GPU
+            transformer_2.to(self.quantize_device, dtype=dtype)
+            flush()
+
+        if self.model_config.quantize:
+            # todo handle two ARAs
+            self.print_and_status_update("Quantizing Transformer 2")
+            quantize_model(self, transformer_2)
+            flush()
+
+        if self.model_config.low_vram:
+            self.print_and_status_update("Moving transformer 2 to CPU")
+            transformer_2.to("cpu")
+
+        # make the combined model
+        self.print_and_status_update("Creating DualWanTransformer3DModel")
+        transformer = DualWanTransformer3DModel(
+            transformer_1=transformer_1,
+            transformer_2=transformer_2,
+            torch_dtype=self.torch_dtype,
+            device=self.device_torch,
+            boundary_ratio=boundary_ratio_t2v,
+            low_vram=self.model_config.low_vram,
+        )
+ + return transformer + + def get_generation_pipeline(self): + scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config) + pipeline = Wan22Pipeline( + vae=self.vae, + transformer=self.model.transformer_1, + transformer_2=self.model.transformer_2, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=scheduler, + expand_timesteps=self._wan_expand_timesteps, + device=self.device_torch, + aggressive_offload=self.model_config.low_vram, + # todo detect if it is i2v or t2v + boundary_ratio=boundary_ratio_t2v, + ) + + # pipeline = pipeline.to(self.device_torch) + + return pipeline + + # static method to get the scheduler + @staticmethod + def get_train_scheduler(): + scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + return scheduler + + def get_base_model_version(self): + return "wan_2.2_14b" + + def generate_single_image( + self, + pipeline: AggressiveWanUnloadPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + return super().generate_single_image( + pipeline=pipeline, + gen_config=gen_config, + conditional_embeds=conditional_embeds, + unconditional_embeds=unconditional_embeds, + generator=generator, + extra=extra, + ) + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + batch: DataLoaderBatchDTO, + **kwargs, + ): + # todo do we need to override this? Adjust timesteps? + return super().get_noise_prediction( + latent_model_input=latent_model_input, + timestep=timestep, + text_embeddings=text_embeddings, + batch=batch, + **kwargs, + ) + + def get_model_has_grad(self): + return False + + def get_te_has_grad(self): + return False + + def save_model(self, output_path, meta, save_dtype): + transformer_combo: DualWanTransformer3DModel = unwrap_model(self.model) + transformer_combo.transformer_1.save_pretrained( + save_directory=os.path.join(output_path, "transformer"), + safe_serialization=True, + ) + transformer_combo.transformer_2.save_pretrained( + save_directory=os.path.join(output_path, "transformer_2"), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, "aitk_meta.yaml") + with open(meta_path, "w") as f: + yaml.dump(meta, f) + + def save_lora( + self, + state_dict: Dict[str, torch.Tensor], + output_path: str, + metadata: Optional[Dict[str, Any]] = None, + ): + if not self.network.network_config.split_multistage_loras: + # just save as a combo lora + save_file(state_dict, output_path, metadata=metadata) + return + + # we need to build out both dictionaries for high and low noise LoRAs + high_noise_lora = {} + low_noise_lora = {} + + for key in state_dict: + if ".transformer_1." in key: + # this is a high noise LoRA + new_key = key.replace(".transformer_1.", ".") + high_noise_lora[new_key] = state_dict[key] + elif ".transformer_2." 
in key: + # this is a low noise LoRA + new_key = key.replace(".transformer_2.", ".") + low_noise_lora[new_key] = state_dict[key] + + # loras have either LORA_MODEL_NAME_000005000.safetensors or LORA_MODEL_NAME.safetensors + if len(high_noise_lora.keys()) > 0: + # save the high noise LoRA + high_noise_lora_path = output_path.replace( + ".safetensors", "_high_noise.safetensors" + ) + save_file(high_noise_lora, high_noise_lora_path, metadata=metadata) + + if len(low_noise_lora.keys()) > 0: + # save the low noise LoRA + low_noise_lora_path = output_path.replace( + ".safetensors", "_low_noise.safetensors" + ) + save_file(low_noise_lora, low_noise_lora_path, metadata=metadata) + + def load_lora(self, file: str): + # if it doesnt have high_noise or low_noise, it is a combo LoRA + if "_high_noise.safetensors" not in file and "_low_noise.safetensors" not in file: + # this is a combined LoRA, we need to split it up + sd = load_file(file) + return sd + + # we may have been passed the high_noise or the low_noise LoRA path, but we need to load both + high_noise_lora_path = file.replace( + "_low_noise.safetensors", "_high_noise.safetensors" + ) + low_noise_lora_path = file.replace( + "_high_noise.safetensors", "_low_noise.safetensors" + ) + + combined_dict = {} + + if os.path.exists(high_noise_lora_path): + # load the high noise LoRA + high_noise_lora = load_file(high_noise_lora_path) + for key in high_noise_lora: + new_key = key.replace( + "diffusion_model.", "diffusion_model.transformer_1." + ) + combined_dict[new_key] = high_noise_lora[key] + if os.path.exists(low_noise_lora_path): + # load the low noise LoRA + low_noise_lora = load_file(low_noise_lora_path) + for key in low_noise_lora: + new_key = key.replace( + "diffusion_model.", "diffusion_model.transformer_2." 
+ ) + combined_dict[new_key] = low_noise_lora[key] + + return combined_dict diff --git a/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py b/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py index 6c7b93b7..dafa2012 100644 --- a/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py +++ b/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py @@ -12,7 +12,6 @@ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from typing import Any, Callable, Dict, List, Optional, Union - class Wan22Pipeline(WanPipeline): def __init__( self, @@ -52,6 +51,7 @@ class Wan22Pipeline(WanPipeline): num_frames: int = 81, num_inference_steps: int = 50, guidance_scale: float = 5.0, + guidance_scale_2: Optional[float] = None, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -77,11 +77,15 @@ class Wan22Pipeline(WanPipeline): vae_device = self.vae.device transformer_device = self.transformer.device text_encoder_device = self.text_encoder.device - device = self.transformer.device + device = self._exec_device if self._aggressive_offload: print("Unloading vae") self.vae.to("cpu") + print("Unloading transformer") + self.transformer.to("cpu") + if self.transformer_2 is not None: + self.transformer_2.to("cpu") self.text_encoder.to(device) flush() @@ -95,9 +99,14 @@ class Wan22Pipeline(WanPipeline): prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, + guidance_scale_2 ) + + if self.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False @@ -160,6 +169,13 @@ class Wan22Pipeline(WanPipeline): num_warmup_steps = len(timesteps) - \ num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) + + if self.config.boundary_ratio is not None: + boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + current_model = self.transformer with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -167,6 +183,25 @@ class Wan22Pipeline(WanPipeline): continue self._current_timestep = t + + if boundary_timestep is None or t >= boundary_timestep: + if self._aggressive_offload and current_model != self.transformer: + if self.transformer_2 is not None: + self.transformer_2.to("cpu") + self.transformer.to(device) + # wan2.1 or high-noise stage in wan2.2 + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + if self._aggressive_offload and current_model != self.transformer_2: + if self.transformer is not None: + self.transformer.to("cpu") + if self.transformer_2 is not None: + self.transformer_2.to(device) + # low-noise stage in wan2.2 + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + latent_model_input = latents.to(device, transformer_dtype) if self.config.expand_timesteps: # seq_len: num_latent_frames * latent_height//2 * latent_width//2 @@ -176,7 +211,7 @@ class Wan22Pipeline(WanPipeline): else: timestep = t.expand(latents.shape[0]) - noise_pred = self.transformer( + noise_pred = current_model( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, @@ -185,14 +220,14 @@ class Wan22Pipeline(WanPipeline): )[0] if self.do_classifier_free_guidance: - 
noise_uncond = self.transformer( + noise_uncond = current_model( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] - noise_pred = noise_uncond + guidance_scale * \ + noise_pred = noise_uncond + current_guidance_scale * \ (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 @@ -259,10 +294,8 @@ class Wan22Pipeline(WanPipeline): # move transformer back to device if self._aggressive_offload: - print("Moving transformer back to device") - self.transformer.to(self._execution_device) - if self.transformer_2 is not None: - self.transformer_2.to(self._execution_device) + # print("Moving transformer back to device") + # self.transformer.to(self._execution_device) flush() if not return_dict: diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 851237cb..daf3d33c 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -437,19 +437,24 @@ class BaseSDTrainProcess(BaseTrainProcess): # Combine and sort the lists combined_items = safetensors_files + directories + pt_files combined_items.sort(key=os.path.getctime) + + num_saves_to_keep = self.save_config.max_step_saves_to_keep + + if hasattr(self.sd, 'max_step_saves_to_keep_multiplier'): + num_saves_to_keep *= self.sd.max_step_saves_to_keep_multiplier # Use slicing with a check to avoid 'NoneType' error safetensors_to_remove = safetensors_files[ - :-self.save_config.max_step_saves_to_keep] if safetensors_files else [] - pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else [] - directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else [] - embeddings_to_remove = embed_files[:-self.save_config.max_step_saves_to_keep] if embed_files else [] - critic_to_remove = critic_items[:-self.save_config.max_step_saves_to_keep] if critic_items else [] + :-num_saves_to_keep] if safetensors_files else [] + pt_files_to_remove = pt_files[:-num_saves_to_keep] if pt_files else [] + directories_to_remove = directories[:-num_saves_to_keep] if directories else [] + embeddings_to_remove = embed_files[:-num_saves_to_keep] if embed_files else [] + critic_to_remove = critic_items[:-num_saves_to_keep] if critic_items else [] items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove + critic_to_remove # remove all but the latest max_step_saves_to_keep - # items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep] + # items_to_remove = combined_items[:-num_saves_to_keep] # remove duplicates items_to_remove = list(dict.fromkeys(items_to_remove)) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index a825830b..81415a10 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -185,6 +185,9 @@ class NetworkConfig: self.conv_alpha = 9999999999 # -1 automatically finds the largest factor self.lokr_factor = kwargs.get('lokr_factor', -1) + + # for multi stage models + self.split_multistage_loras = kwargs.get('split_multistage_loras', True) AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v'] diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py index bdc2f601..ecdc8f3f 100644 --- a/toolkit/models/wan21/wan21.py +++ b/toolkit/models/wan21/wan21.py @@ -310,6 +310,7 @@ class Wan21(BaseModel): arch = 'wan21' _wan_generation_scheduler_config 
= scheduler_configUniPC _wan_expand_timesteps = False + _wan_vae_path = None _comfy_te_file = ['text_encoders/umt5_xxl_fp16.safetensors', 'text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors'] def __init__( @@ -431,8 +432,14 @@ class Wan21(BaseModel): scheduler = Wan21.get_train_scheduler() self.print_and_status_update("Loading VAE") # todo, example does float 32? check if quality suffers - vae = AutoencoderKLWan.from_pretrained( - vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype) + + if self._wan_vae_path is not None: + # load the vae from individual repo + vae = AutoencoderKLWan.from_pretrained( + self._wan_vae_path, torch_dtype=dtype).to(dtype=dtype) + else: + vae = AutoencoderKLWan.from_pretrained( + vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype) flush() self.print_and_status_update("Making pipe") diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index 796e4042..59c15d3d 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -565,6 +565,13 @@ class ToolkitNetworkMixin: if metadata is None: metadata = OrderedDict() metadata = add_model_hash_to_meta(save_dict, metadata) + # let the model handle the saving + + if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'save_lora'): + # call the base model save lora method + self.base_model_ref().save_lora(save_dict, file, metadata) + return + if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import save_file save_file(save_dict, file, metadata) @@ -577,12 +584,15 @@ class ToolkitNetworkMixin: keymap = {} if keymap is None else keymap if isinstance(file, str): - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file - - weights_sd = load_file(file) + if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'load_lora'): + # call the base model load lora method + weights_sd = self.base_model_ref().load_lora(file) else: - weights_sd = torch.load(file, map_location="cpu") + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") else: # probably a state dict weights_sd = file
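
For reference, a minimal sketch of the filename and key round trip that save_lora() and load_lora() in wan22_14b_model.py implement when split_multistage_loras is enabled. This is illustrative only and not part of the patch; the block and key names are invented, and plain strings stand in for tensors.

state_dict = {
    "diffusion_model.transformer_1.blocks.0.attn.to_q.lora_down.weight": "high-noise tensor",
    "diffusion_model.transformer_2.blocks.0.attn.to_q.lora_down.weight": "low-noise tensor",
}

# on save: strip the transformer marker and route keys to the *_high_noise / *_low_noise files
high = {k.replace(".transformer_1.", "."): v for k, v in state_dict.items() if ".transformer_1." in k}
low = {k.replace(".transformer_2.", "."): v for k, v in state_dict.items() if ".transformer_2." in k}
# e.g. my_lora.safetensors -> my_lora_high_noise.safetensors and my_lora_low_noise.safetensors

# on load: re-insert the marker so both experts land back in one combined state dict
combined = {}
combined.update({k.replace("diffusion_model.", "diffusion_model.transformer_1."): v for k, v in high.items()})
combined.update({k.replace("diffusion_model.", "diffusion_model.transformer_2."): v for k, v in low.items()})
assert set(combined) == set(state_dict)  # the split/merge is lossless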