diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py index edfe0c13..1828f835 100644 --- a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py @@ -9,63 +9,69 @@ from PIL import Image from toolkit.models.base_model import BaseModel from toolkit.basic import flush from toolkit.prompt_utils import PromptEmbeds -from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) from toolkit.accelerator import get_accelerator, unwrap_model from optimum.quanto import freeze, QTensor from toolkit.util.quantize import quantize, get_qtype, quantize_model import torch.nn.functional as F +from toolkit.memory_management import MemoryManager +from safetensors.torch import load_file -from diffusers import QwenImagePipeline, QwenImageTransformer2DModel, AutoencoderKLQwenImage -from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor +from diffusers import ( + QwenImagePipeline, + QwenImageTransformer2DModel, + AutoencoderKLQwenImage, +) +from transformers import ( + Qwen2_5_VLForConditionalGeneration, + Qwen2Tokenizer, + Qwen2VLProcessor, +) from tqdm import tqdm if TYPE_CHECKING: from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO scheduler_config = { - "base_image_seq_len": 256, - "base_shift": 0.5, - "invert_sigmas": False, - "max_image_seq_len": 8192, - "max_shift": 0.9, - "num_train_timesteps": 1000, - "shift": 1.0, - "shift_terminal": 0.02, - "stochastic_sampling": False, - "time_shift_type": "exponential", - "use_beta_sigmas": False, - "use_dynamic_shifting": True, - "use_exponential_sigmas": False, - "use_karras_sigmas": False + "base_image_seq_len": 256, + "base_shift": 0.5, + "invert_sigmas": False, + "max_image_seq_len": 8192, + "max_shift": 0.9, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": 0.02, + "stochastic_sampling": False, + "time_shift_type": "exponential", + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False, } - class QwenImageModel(BaseModel): arch = "qwen_image" _qwen_image_keep_visual = False _qwen_pipeline = QwenImagePipeline def __init__( - self, - device, - model_config: ModelConfig, - dtype='bf16', - custom_pipeline=None, - noise_scheduler=None, - **kwargs + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, ): super().__init__( - device, - model_config, - dtype, - custom_pipeline, - noise_scheduler, - **kwargs + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs ) self.is_flow_matching = True self.is_transformer = True - self.target_lora_modules = ['QwenImageTransformer2DModel'] + self.target_lora_modules = ["QwenImageTransformer2DModel"] # static method to get the noise scheduler @staticmethod @@ -73,40 +79,58 @@ class QwenImageModel(BaseModel): return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) def get_bucket_divisibility(self): - return 16 * 2 # 16 for the VAE, 2 for patch size + return 16 * 2 # 16 for the VAE, 2 for patch size def load_model(self): dtype = self.torch_dtype self.print_and_status_update("Loading Qwen Image model") model_path = self.model_config.name_or_path base_model_path = self.model_config.extras_name_or_path + 
model_dtype = dtype - transformer_path = model_path - transformer_subfolder = 'transformer' - if os.path.exists(transformer_path): - transformer_subfolder = None - transformer_path = os.path.join(transformer_path, 'transformer') - # check if the path is a full checkpoint. - te_folder_path = os.path.join(model_path, 'text_encoder') - # if we have the te, this folder is a full checkpoint, use it as the base - if os.path.exists(te_folder_path): - base_model_path = model_path + if base_model_path.endswith(".safetensors"): + # use the repo for extras + base_model_path = "Qwen/Qwen-Image" self.print_and_status_update("Loading transformer") - transformer = QwenImageTransformer2DModel.from_pretrained( - transformer_path, - subfolder=transformer_subfolder, - torch_dtype=dtype - ) + + if model_path.endswith(".safetensors"): + # load the safetensors file + transformer = QwenImageTransformer2DModel.from_single_file( + model_path, + config="Qwen/Qwen-Image", + subfolder="transformer", + torch_dtype=model_dtype, + ) + transformer.to(model_dtype) + + else: + transformer_path = model_path + transformer_subfolder = "transformer" + if os.path.exists(transformer_path): + transformer_subfolder = None + transformer_path = os.path.join(transformer_path, "transformer") + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, "text_encoder") + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + transformer = QwenImageTransformer2DModel.from_pretrained( + transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype + ) if self.model_config.quantize: self.print_and_status_update("Quantizing Transformer") quantize_model(self, transformer) flush() - + + if self.model_config.auto_memory: + MemoryManager.attach(transformer, self.device_torch) + if self.model_config.low_vram: self.print_and_status_update("Moving transformer to CPU") - transformer.to('cpu') + transformer.to("cpu") flush() @@ -117,32 +141,35 @@ class QwenImageModel(BaseModel): text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( base_model_path, subfolder="text_encoder", torch_dtype=dtype ) - + # remove the visual model as it is not needed for image generation self.processor = None if not self._qwen_image_keep_visual: text_encoder.model.visual = None + if self.model_config.auto_memory: + MemoryManager.attach(text_encoder, self.device_torch) + text_encoder.to(self.device_torch, dtype=dtype) flush() if self.model_config.quantize_te: self.print_and_status_update("Quantizing Text Encoder") - quantize(text_encoder, weights=get_qtype( - self.model_config.qtype_te)) + quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te)) freeze(text_encoder) flush() self.print_and_status_update("Loading VAE") vae = AutoencoderKLQwenImage.from_pretrained( - base_model_path, subfolder="vae", torch_dtype=dtype) + base_model_path, subfolder="vae", torch_dtype=dtype + ) self.noise_scheduler = QwenImageModel.get_train_scheduler() self.print_and_status_update("Making pipe") - + kwargs = {} - + if self._qwen_image_keep_visual: try: self.processor = Qwen2VLProcessor.from_pretrained( @@ -152,7 +179,7 @@ class QwenImageModel(BaseModel): self.processor = Qwen2VLProcessor.from_pretrained( base_model_path, subfolder="processor" ) - kwargs['processor'] = self.processor + kwargs["processor"] = self.processor pipe: QwenImagePipeline = self._qwen_pipeline( scheduler=self.noise_scheduler, @@ -160,7 +187,7 @@ class 
QwenImageModel(BaseModel): tokenizer=tokenizer, vae=vae, transformer=None, - **kwargs + **kwargs, ) # for quantization, it works best to do these after making the pipe pipe.text_encoder = text_encoder @@ -198,7 +225,7 @@ class QwenImageModel(BaseModel): text_encoder=unwrap_model(self.text_encoder[0]), tokenizer=self.tokenizer[0], vae=unwrap_model(self.vae), - transformer=unwrap_model(self.transformer) + transformer=unwrap_model(self.transformer), ) pipeline = pipeline.to(self.device_torch) @@ -231,22 +258,27 @@ class QwenImageModel(BaseModel): # flush for low vram if we are doing that flush_between_steps = self.model_config.low_vram - # Fix a bug in diffusers/torch + + # Fix a bug in diffusers/torch def callback_on_step_end(pipe, i, t, callback_kwargs): if flush_between_steps: flush() latents = callback_kwargs["latents"] - + return {"latents": latents} - + sc = self.get_bucket_divisibility() - gen_config.width = int(gen_config.width // sc * sc) + gen_config.width = int(gen_config.width // sc * sc) gen_config.height = int(gen_config.height // sc * sc) img = pipeline( prompt_embeds=conditional_embeds.text_embeds, - prompt_embeds_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=torch.int64), + prompt_embeds_mask=conditional_embeds.attention_mask.to( + self.device_torch, dtype=torch.int64 + ), negative_prompt_embeds=unconditional_embeds.text_embeds, - negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=torch.int64), + negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to( + self.device_torch, dtype=torch.int64 + ), height=gen_config.height, width=gen_config.width, num_inference_steps=gen_config.num_inference_steps, @@ -254,7 +286,7 @@ class QwenImageModel(BaseModel): latents=gen_config.latents, generator=generator, callback_on_step_end=callback_on_step_end, - **extra + **extra, ).images[0] return img @@ -263,28 +295,36 @@ class QwenImageModel(BaseModel): latent_model_input: torch.Tensor, timestep: torch.Tensor, # 0 to 1000 scale text_embeddings: PromptEmbeds, - **kwargs + **kwargs, ): self.model.to(self.device_torch) batch_size, num_channels_latents, height, width = latent_model_input.shape - + ps = self.transformer.config.patch_size # pack image tokens - latent_model_input = latent_model_input.view(batch_size, num_channels_latents, height // ps, ps, width // ps, ps) + latent_model_input = latent_model_input.view( + batch_size, num_channels_latents, height // ps, ps, width // ps, ps + ) latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5) - latent_model_input = latent_model_input.reshape(batch_size, (height // ps) * (width // ps), num_channels_latents * (ps * ps)) + latent_model_input = latent_model_input.reshape( + batch_size, (height // ps) * (width // ps), num_channels_latents * (ps * ps) + ) # img_shapes passed to the model img_h2, img_w2 = height // ps, width // ps img_shapes = [[(1, img_h2, img_w2)]] * batch_size enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype) - prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64) + prompt_embeds_mask = text_embeddings.attention_mask.to( + self.device_torch, dtype=torch.int64 + ) txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() noise_pred = self.transformer( - hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype).detach(), + hidden_states=latent_model_input.to( + self.device_torch, self.torch_dtype + ).detach(), timestep=(timestep / 1000).detach(), guidance=None, 
encoder_hidden_states=enc_hs.detach(), @@ -296,56 +336,55 @@ class QwenImageModel(BaseModel): )[0] # unpack - noise_pred = noise_pred.view(batch_size, height // ps, width // ps, num_channels_latents, ps, ps) + noise_pred = noise_pred.view( + batch_size, height // ps, width // ps, num_channels_latents, ps, ps + ) noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5) noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width) return noise_pred - + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: if self.pipeline.text_encoder.device != self.device_torch: self.pipeline.text_encoder.to(self.device_torch) - + prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt( prompt, device=self.device_torch, num_images_per_prompt=1, ) - pe = PromptEmbeds( - prompt_embeds - ) + pe = PromptEmbeds(prompt_embeds) pe.attention_mask = prompt_embeds_mask return pe - + def get_model_has_grad(self): return False def get_te_has_grad(self): return False - + def save_model(self, output_path, meta, save_dtype): # only save the unet transformer: QwenImageTransformer2DModel = unwrap_model(self.model) transformer.save_pretrained( - save_directory=os.path.join(output_path, 'transformer'), + save_directory=os.path.join(output_path, "transformer"), safe_serialization=True, ) - meta_path = os.path.join(output_path, 'aitk_meta.yaml') - with open(meta_path, 'w') as f: + meta_path = os.path.join(output_path, "aitk_meta.yaml") + with open(meta_path, "w") as f: yaml.dump(meta, f) def get_loss_target(self, *args, **kwargs): - noise = kwargs.get('noise') - batch = kwargs.get('batch') + noise = kwargs.get("noise") + batch = kwargs.get("batch") return (noise - batch.latents).detach() - def get_base_model_version(self): return "qwen_image" - + def get_transformer_block_names(self) -> Optional[List[str]]: - return ['transformer_blocks'] - + return ["transformer_blocks"] + def convert_lora_weights_before_save(self, state_dict): new_sd = {} for key, value in state_dict.items(): @@ -359,20 +398,15 @@ class QwenImageModel(BaseModel): new_key = key.replace("diffusion_model.", "transformer.") new_sd[new_key] = value return new_sd - - def encode_images( - self, - image_list: List[torch.Tensor], - device=None, - dtype=None - ): + + def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None): if device is None: device = self.vae_device_torch if dtype is None: dtype = self.vae_torch_dtype # Move to vae to device if on cpu - if self.vae.device == 'cpu': + if self.vae.device == "cpu": self.vae.to(device) self.vae.eval() self.vae.requires_grad_(False) @@ -383,20 +417,19 @@ class QwenImageModel(BaseModel): images = images.unsqueeze(2) latents = self.vae.encode(images).latent_dist.sample() - + latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) .to(latents.device, latents.dtype) ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) - + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * latents_std latents = latents.to(device, dtype=dtype) - - + latents = latents.squeeze(2) # remove the frame count dimension - return latents \ No newline at end of file + return latents diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index a75eb3b2..6a696d6c 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ 
b/jobs/process/BaseSDTrainProcess.py
@@ -1759,7 +1759,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             )
 
         # we cannot merge in if quantized
-        if self.model_config.quantize:
+        if self.model_config.quantize or self.model_config.auto_memory:
             # todo find a way around this
             self.network.can_merge_in = False
 
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 71871617..cb3c6c80 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -624,6 +624,15 @@ class ModelConfig:
 
         self.arch: ModelArch = kwargs.get("arch", None)
 
+        # auto memory management, only for some models
+        self.auto_memory = kwargs.get("auto_memory", False)
+        if self.auto_memory and self.qtype == "qfloat8":
+            print(f"Auto memory is not compatible with qfloat8, switching to float8 for model")
+            self.qtype = "float8"
+        if self.auto_memory and self.qtype_te == "qfloat8":
+            print(f"Auto memory is not compatible with qfloat8, switching to float8 for te")
+            self.qtype_te = "float8"
+
         # can be used to load the extras like text encoder or vae from here
         # only setup for some models but will prevent having to download the te for
         # 20 different model variants
@@ -650,6 +659,7 @@ class ModelConfig:
 
         if self.arch == "flex1":
             self.arch = "flux"
+
         # handle migrating to new model arch
         if self.arch is not None:
 
diff --git a/toolkit/memory_management/manager.py b/toolkit/memory_management/manager.py
index 4480ea94..fa7a2d07 100644
--- a/toolkit/memory_management/manager.py
+++ b/toolkit/memory_management/manager.py
@@ -1,12 +1,92 @@
-from typing import TYPE_CHECKING
+import torch
+from .manager_modules import LinearLayerMemoryManager, ConvLayerMemoryManager
 
-if TYPE_CHECKING:
-    from toolkit.models.base_model import BaseModel
+LINEAR_MODULES = [
+    "Linear",
+    "LoRACompatibleLinear",
+    "QLinear",
+]
+CONV_MODULES = [
+    "Conv2d",
+    "LoRACompatibleConv",
+    "QConv2d",
+]
+
+UNMANAGED_MODULES = [
+    "LayerNorm",
+    "BatchNorm1d",
+    "BatchNorm2d",
+    "BatchNorm3d",
+    "GroupNorm",
+    "InstanceNorm1d",
+    "InstanceNorm2d",
+    "InstanceNorm3d",
+    "Embedding",
+    "EmbeddingBag",
+    "RNNBase",
+    "LSTM",
+    "GRU",
+    "RNN",
+]
+
+UNMANAGED_MODULES_INCLUDES = ["RotaryEmbedding", "Norm"]
 
 
 class MemoryManager:
     def __init__(
         self,
-        model: "BaseModel",
+        module: torch.nn.Module,
+        process_device: torch.device = torch.device("cpu"),
     ):
-        self.model: "BaseModel" = model
+        self.module: torch.nn.Module = module
+        self.process_device: torch.device = process_device
+        self.unmanaged_modules: list[torch.nn.Module] = []
+
+    def memory_managed_to(self, *args, **kwargs):
+        # first move all the unmanaged modules
+        for module in self.unmanaged_modules:
+            module.to(*args, **kwargs)
+        # check for a dtype argument
+        dtype = None
+        if "dtype" in kwargs:
+            dtype = kwargs["dtype"]
+        elif len(args) > 0:
+            for i, arg in enumerate(args):
+                if isinstance(arg, torch.dtype):
+                    dtype = arg
+                    break
+        if dtype is not None:
+            return self.module._mm_to(dtype=dtype)
+        return self.module
+
+    @classmethod
+    def attach(cls, module: torch.nn.Module, device: torch.device):
+        if hasattr(module, "_memory_manager"):
+            # already attached
+            return
+
+        module._memory_manager = cls(module, device)
+
+        # override the to method to handle memory management
+        module._mm_to = module.to
+        module.to = module._memory_manager.memory_managed_to
+
+        # attach to all modules
+        for name, sub_module in module.named_modules():
+            for child_name, child_module in sub_module.named_modules():
+                if child_module.__class__.__name__ in LINEAR_MODULES:
+                    # linear
+                    LinearLayerMemoryManager.attach(
child_module, module._memory_manager + ) + elif child_module.__class__.__name__ in CONV_MODULES: + # conv + ConvLayerMemoryManager.attach(child_module, module._memory_manager) + elif child_module.__class__.__name__ in UNMANAGED_MODULES or any( + inc in child_module.__class__.__name__ + for inc in UNMANAGED_MODULES_INCLUDES + ): + # unmanaged + module._memory_manager.unmanaged_modules.append(child_module) + else: + continue diff --git a/toolkit/memory_management/manager_modules.py b/toolkit/memory_management/manager_modules.py new file mode 100644 index 00000000..4898dfc3 --- /dev/null +++ b/toolkit/memory_management/manager_modules.py @@ -0,0 +1,450 @@ +""" +This code was heavily inspired by the work of Lodestone-Rock, pretty much all credit goes +to them. The original code can be found here: +https://github.com/lodestone-rock/RamTorch/blob/main/ramtorch/modules/linear.py + +I simply modified it to work with a memory management model and with AI Toolkit's models +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import TYPE_CHECKING, Optional, Tuple + +if TYPE_CHECKING: + from .manager import MemoryManager + +# --- Per-device global state registry --- +_DEVICE_STATE = {} + + +def _get_device_state(device: torch.device): + """Get or initialize per-device state.""" + if isinstance(device, str): + device = torch.device(device) + + # CPU path needs no CUDA state + if device.type != "cuda": + if device not in _DEVICE_STATE: + _DEVICE_STATE[device] = {} + return _DEVICE_STATE[device] + + if device not in _DEVICE_STATE: + with torch.cuda.device(device): + _DEVICE_STATE[device] = { + # streams & events + "transfer_stream": torch.cuda.Stream(device=device), + "transfer_grad_stream": torch.cuda.Stream(device=device), + "transfer_forward_finished_event": torch.cuda.Event(), + "compute_forward_start_event": torch.cuda.Event(), + "transfer_backward_finished_event": torch.cuda.Event(), + "transfer_weight_backward_finished_event": torch.cuda.Event(), + "compute_backward_start_event": torch.cuda.Event(), + "compute_backward_finished_event": torch.cuda.Event(), + # ping-pong buffers + "w_buffers": [None, None], + "b_buffers": [None, None], + "w_bwd_buffers": [None, None], + # device-side staging for grads to be sent to CPU + "w_grad_buffers": [None, None], + "b_grad_buffers": [None, None], + # clocks + "forward_clk": 0, + "backward_clk": 0, + } + return _DEVICE_STATE[device] + + +def _ensure_cpu_pinned(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if t is None: + return None + if t.device.type != "cpu": + t = t.to("cpu", copy=True) + if torch.cuda.is_available(): + try: + t = t.pin_memory() + except RuntimeError: + pass + return t + + +def _move_params_to_cpu_and_pin(module: nn.Module): + """Force parameters to CPU (+pinned) so we can 'bounce' them per forward/backward.""" + with torch.no_grad(): + if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter): + module.weight.data = _ensure_cpu_pinned(module.weight.data).detach() + if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): + if module.bias is not None: + module.bias.data = _ensure_cpu_pinned(module.bias.data).detach() + + +# ========================== +# Autograd functions (CUDA) +# ========================== + + +class _BouncingLinearFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight_cpu, bias_cpu, device: torch.device): + if device.type != "cuda": + out = F.linear(x.to("cpu"), weight_cpu, bias_cpu) + ctx.save_for_backward(x.to("cpu"), 
weight_cpu, bias_cpu) + ctx.device = torch.device("cpu") + return out.to(x.device) + + state = _get_device_state(device) + ts = state["transfer_stream"] + w_bufs, b_bufs = state["w_buffers"], state["b_buffers"] + ev_tx_f = state["transfer_forward_finished_event"] + ev_cu_s = state["compute_forward_start_event"] + idx = state["forward_clk"] + + with torch.cuda.stream(ts): + ts.wait_event(ev_cu_s) + w_bufs[idx] = weight_cpu.to(device, non_blocking=True) + b_bufs[idx] = ( + bias_cpu.to(device, non_blocking=True) if bias_cpu is not None else None + ) + state["forward_clk"] ^= 1 + ev_tx_f.record() + + torch.cuda.current_stream().wait_event(ev_tx_f) + ev_cu_s.record() + out = F.linear(x, w_bufs[idx], b_bufs[idx]) + + ctx.save_for_backward(x, weight_cpu, bias_cpu) + ctx.device = device + return out + + @staticmethod + def backward(ctx, grad_out): + x, weight_cpu, bias_cpu = ctx.saved_tensors + device = ctx.device + + if device.type != "cuda": + go_cpu = grad_out.to("cpu") + x_cpu = x.to("cpu") + grad_input = go_cpu @ weight_cpu + grad_weight = go_cpu.flatten(0, -2).T @ x_cpu.flatten(0, -2) + grad_bias = ( + go_cpu.sum(dim=tuple(range(go_cpu.ndim - 1))) + if bias_cpu is not None + else None + ) + return grad_input.to(grad_out.device), grad_weight, grad_bias, None + + state = _get_device_state(device) + transfer_stream = state["transfer_stream"] + transfer_grad_stream = state["transfer_grad_stream"] + + w_bwd_buffers = state["w_bwd_buffers"] + w_grad_buffers = state["w_grad_buffers"] + b_grad_buffers = state["b_grad_buffers"] + + ev_tx_b = state["transfer_backward_finished_event"] + ev_tx_w_bwd_done = state["transfer_weight_backward_finished_event"] + ev_cu_b_start = state["compute_backward_start_event"] + ev_cu_b_finish = state["compute_backward_finished_event"] + + idx = state["backward_clk"] + + # Stage weights onto device (transfer stream), ping-pong to avoid races + with torch.cuda.stream(transfer_stream): + transfer_stream.wait_event(ev_cu_b_start) + w_bwd_buffers[idx] = weight_cpu.to(device, non_blocking=True) + state["backward_clk"] ^= 1 + ev_tx_b.record() + + # Compute stream waits for weights to arrive, then start compute + torch.cuda.current_stream().wait_event(ev_tx_b) + ev_cu_b_start.record() + + # 1) Compute grad_input using the freshly transferred weights + grad_input = grad_out @ w_bwd_buffers[idx] + + # 2) Ensure previous grad-to-CPU transfer that used this slot finished + torch.cuda.current_stream().wait_event(ev_tx_w_bwd_done) + + # 3) Compute weight/bias grads on GPU into staging buffers + w_grad_buffers[idx] = grad_out.flatten(0, -2).T @ x.flatten(0, -2) + if bias_cpu is not None: + reduce_dims = tuple(range(grad_out.ndim - 1)) + b_grad_buffers[idx] = grad_out.sum(dim=reduce_dims) + + # Mark end of GPU compute + ev_cu_b_finish.record() + + # 4) Launch non-blocking H2D->CPU transfers on a separate grad stream (full-duplex) + with torch.cuda.stream(transfer_grad_stream): + transfer_grad_stream.wait_event(ev_cu_b_finish) + grad_weight = w_grad_buffers[idx].to("cpu", non_blocking=True) + grad_bias = ( + b_grad_buffers[idx].to("cpu", non_blocking=True) + if bias_cpu is not None + else None + ) + # signal that this slot's CPU transfer is complete (safe for next reuse) + state["transfer_weight_backward_finished_event"].record() + + return grad_input, grad_weight, grad_bias, None + + +class _BouncingConv2dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight_cpu, + bias_cpu, + device: torch.device, + stride: Tuple[int, int], + padding: Tuple[int, int], + 
dilation: Tuple[int, int], + groups: int, + ): + if device.type != "cuda": + out = F.conv2d( + x.to("cpu"), weight_cpu, bias_cpu, stride, padding, dilation, groups + ) + ctx.save_for_backward(x.to("cpu"), weight_cpu, bias_cpu) + ctx.meta = ("cpu", stride, padding, dilation, groups) + return out.to(x.device) + + state = _get_device_state(device) + ts = state["transfer_stream"] + w_bufs, b_bufs = state["w_buffers"], state["b_buffers"] + ev_tx_f = state["transfer_forward_finished_event"] + ev_cu_s = state["compute_forward_start_event"] + idx = state["forward_clk"] + + with torch.cuda.stream(ts): + ts.wait_event(ev_cu_s) + w_bufs[idx] = weight_cpu.to(device, non_blocking=True) + b_bufs[idx] = ( + bias_cpu.to(device, non_blocking=True) if bias_cpu is not None else None + ) + state["forward_clk"] ^= 1 + ev_tx_f.record() + + torch.cuda.current_stream().wait_event(ev_tx_f) + ev_cu_s.record() + out = F.conv2d(x, w_bufs[idx], b_bufs[idx], stride, padding, dilation, groups) + + ctx.save_for_backward(x, weight_cpu, bias_cpu) + ctx.meta = (device, stride, padding, dilation, groups) + return out + + @staticmethod + def backward(ctx, grad_out): + x, weight_cpu, bias_cpu = ctx.saved_tensors + meta = ctx.meta + device, stride, padding, dilation, groups = meta + + if ( + isinstance(device, torch.device) and device.type != "cuda" + ) or device == "cpu": + # CPU grads + go = grad_out.to("cpu") + x_cpu = x.to("cpu") + w_cpu = weight_cpu + from torch.nn.grad import conv2d_input, conv2d_weight # type: ignore + + grad_input = conv2d_input( + x_cpu.shape, + w_cpu, + go, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + grad_weight = conv2d_weight( + x_cpu, + w_cpu.shape, + go, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + grad_bias = go.sum(dim=(0, 2, 3)) if bias_cpu is not None else None + return ( + grad_input.to(grad_out.device), + grad_weight, + grad_bias, + None, + None, + None, + None, + None, + ) + + # CUDA path (full-duplex) + state = _get_device_state(device) + transfer_stream = state["transfer_stream"] + transfer_grad_stream = state["transfer_grad_stream"] + + # device-side buffers + w_bwd_buffers = state["w_bwd_buffers"] + w_grad_buffers = state["w_grad_buffers"] + b_grad_buffers = state["b_grad_buffers"] + + ev_tx_b = state["transfer_backward_finished_event"] + ev_tx_w_bwd_done = state["transfer_weight_backward_finished_event"] + ev_cu_b_start = state["compute_backward_start_event"] + ev_cu_b_finish = state["compute_backward_finished_event"] + + idx = state["backward_clk"] + + # Stage weights for input-grad compute + with torch.cuda.stream(transfer_stream): + transfer_stream.wait_event(ev_cu_b_start) + w_bwd_buffers[idx] = weight_cpu.to(device, non_blocking=True) + state["backward_clk"] ^= 1 + ev_tx_b.record() + + torch.cuda.current_stream().wait_event(ev_tx_b) + ev_cu_b_start.record() + + # grad wrt input on GPU with streamed weights + from torch.nn.grad import conv2d_input, conv2d_weight # type: ignore + + grad_input = conv2d_input( + x.shape, + w_bwd_buffers[idx], + grad_out, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + # Ensure previous grad transfer that used this slot is done + torch.cuda.current_stream().wait_event(ev_tx_w_bwd_done) + + # Compute heavy grads on GPU into staging buffers + w_grad_buffers[idx] = conv2d_weight( + x, + weight_cpu.shape, + grad_out, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + if bias_cpu is not None: + b_grad_buffers[idx] = 
grad_out.sum(dim=(0, 2, 3)) + + # Mark end of GPU math + ev_cu_b_finish.record() + + # Launch CPU copies on the dedicated grad stream (overlaps with next H2D) + with torch.cuda.stream(transfer_grad_stream): + transfer_grad_stream.wait_event(ev_cu_b_finish) + grad_weight = w_grad_buffers[idx].to("cpu", non_blocking=True) + grad_bias = ( + b_grad_buffers[idx].to("cpu", non_blocking=True) + if bias_cpu is not None + else None + ) + state["transfer_weight_backward_finished_event"].record() + + return grad_input, grad_weight, grad_bias, None, None, None, None, None + + +class BaseLayerMemoryManager: + def __init__( + self, + module: nn.Module, + manager: "MemoryManager", + ): + self.module: nn.Module = module + self.manager: "MemoryManager" = manager + + @classmethod + def attach(cls, module: nn.Module, manager: "MemoryManager"): + if hasattr(module, "_layer_memory_manager"): + return + module._layer_memory_manager = cls(module, manager) + + # mark parameters as memory managed + for param in module.parameters(recurse=False): + param._is_memory_managed = True + + +class LinearLayerMemoryManager(BaseLayerMemoryManager): + def __init__( + self, + module: nn.Module, + manager: "MemoryManager", + ): + super().__init__(module, manager) + + # 1) Move params to CPU + pin memory for fast H2D + _move_params_to_cpu_and_pin(self.module) + + # 2) Hijack forward + self._original_forward = getattr(self.module, "forward") + + def _mm_forward(x, *args, **kwargs): + # ensure we only use expected signature (Linear: x) + if args or kwargs: + # fall back to original if a custom signature is used + return self._original_forward(x, *args, **kwargs) + + weight_cpu = self.module.weight + bias_cpu = getattr(self.module, "bias", None) + device = self.manager.process_device + + # NOTE: do NOT move params to device here; autograd fn streams & bounces them + return _BouncingLinearFn.apply(x, weight_cpu, bias_cpu, device) + + self.module.forward = _mm_forward + + +class ConvLayerMemoryManager(BaseLayerMemoryManager): + def __init__( + self, + module: nn.Module, + manager: "MemoryManager", + ): + super().__init__(module, manager) + + # 1) Move params to CPU + pin memory for fast H2D + _move_params_to_cpu_and_pin(self.module) + + # Cache static conv attributes from the module + stride = ( + self.module.stride + if isinstance(self.module.stride, tuple) + else (self.module.stride, self.module.stride) + ) + padding = ( + self.module.padding + if isinstance(self.module.padding, tuple) + else (self.module.padding, self.module.padding) + ) + dilation = ( + self.module.dilation + if isinstance(self.module.dilation, tuple) + else (self.module.dilation, self.module.dilation) + ) + groups = self.module.groups + + # 2) Hijack forward + self._original_forward = getattr(self.module, "forward") + + def _mm_forward(x, *args, **kwargs): + # Support the typical Conv2d(x) call; if user passes uncommon extras, fallback. 
+ if args or kwargs: + return self._original_forward(x, *args, **kwargs) + + weight_cpu = self.module.weight + bias_cpu = getattr(self.module, "bias", None) + device = self.manager.process_device + + return _BouncingConv2dFn.apply( + x, weight_cpu, bias_cpu, device, stride, padding, dilation, groups + ) + + self.module.forward = _mm_forward diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index dd070249..58d48b4f 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -41,7 +41,6 @@ from torchvision.transforms import functional as TF from toolkit.accelerator import get_accelerator, unwrap_model from typing import TYPE_CHECKING from toolkit.print import print_acc -from toolkit.memory_management import MemoryManager if TYPE_CHECKING: from toolkit.lora_special import LoRASpecialNetwork @@ -186,8 +185,6 @@ class BaseModel: self.has_multiple_control_images = False # do not resize control images self.use_raw_control_images = False - - self.memory_manager = MemoryManager(self) # properties for old arch for backwards compatibility @property diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 8b72a1c5..78960ed1 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -70,7 +70,6 @@ from typing import TYPE_CHECKING from toolkit.print import print_acc from diffusers import FluxFillPipeline from transformers import AutoModel, AutoTokenizer, Gemma2Model, Qwen2Model, LlamaModel -from toolkit.memory_management import MemoryManager if TYPE_CHECKING: from toolkit.lora_special import LoRASpecialNetwork @@ -225,8 +224,6 @@ class StableDiffusion: # do not resize control images self.use_raw_control_images = False - self.memory_manager = MemoryManager(self) - # properties for old arch for backwards compatibility @property def is_xl(self): diff --git a/toolkit/util/quantize.py b/toolkit/util/quantize.py index 31a96bd1..e3e81940 100644 --- a/toolkit/util/quantize.py +++ b/toolkit/util/quantize.py @@ -301,14 +301,14 @@ def quantize_model( f" - quantizing {len(all_blocks)} transformer blocks" ) for block in tqdm(all_blocks): - block.to(base_model.device_torch, dtype=base_model.torch_dtype) + block.to(base_model.device_torch, dtype=base_model.torch_dtype, non_blocking=True) quantize(block, weights=quantization_type) freeze(block) - block.to("cpu") + block.to("cpu", non_blocking=True) # todo, on extras find a universal way to quantize them on device and move them back to their original # device without having to move the transformer blocks to the device first base_model.print_and_status_update(" - quantizing extras") - model_to_quantize.to(base_model.device_torch, dtype=base_model.torch_dtype) + # model_to_quantize.to(base_model.device_torch, dtype=base_model.torch_dtype) quantize(model_to_quantize, weights=quantization_type) freeze(model_to_quantize)
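A minimal usage sketch of the new MemoryManager path (not part of the diff above). The toy nn.Sequential model and the direct attach() call are illustrative assumptions; in the PR itself, attach() is called from qwen_image.py on the transformer and text encoder when ModelConfig.auto_memory is set, and auto_memory is presumably switched on from the model section of a training config (it arrives as a ModelConfig kwarg).

import torch
import torch.nn as nn

from toolkit.memory_management import MemoryManager

# Device the managed weights will be streamed to during forward/backward.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy module standing in for a transformer block (assumption, for illustration only).
net = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64))

# Attach the manager: Linear/Conv2d children get their parameters moved to pinned
# CPU memory and their forward() hijacked so weights are "bounced" to the GPU per
# call; norm/embedding-style modules are tracked as unmanaged and still follow
# normal .to() moves.
MemoryManager.attach(net, device)

# .to() now routes through memory_managed_to(): only unmanaged modules (and dtype
# changes) are applied; managed weights stay resident on CPU.
net.to(device)

x = torch.randn(4, 64, device=device)
out = net(x)          # weights are streamed host-to-device inside the autograd functions
out.sum().backward()  # weight/bias grads are computed on GPU and copied back to the CPU parameters

The per-layer managers stage weights on a dedicated transfer stream with ping-pong buffers, so parameter uploads and gradient copy-backs overlap with compute instead of keeping the whole model resident in VRAM.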