Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Initial support for RamTorch. Still a WIP
@@ -9,63 +9,69 @@ from PIL import Image
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.samplers.custom_flowmatch_sampler import (
    CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.accelerator import get_accelerator, unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.quantize import quantize, get_qtype, quantize_model
import torch.nn.functional as F
from toolkit.memory_management import MemoryManager
from safetensors.torch import load_file

from diffusers import QwenImagePipeline, QwenImageTransformer2DModel, AutoencoderKLQwenImage
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
from diffusers import (
    QwenImagePipeline,
    QwenImageTransformer2DModel,
    AutoencoderKLQwenImage,
)
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    Qwen2Tokenizer,
    Qwen2VLProcessor,
)
from tqdm import tqdm

if TYPE_CHECKING:
    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": 0.5,
    "invert_sigmas": False,
    "max_image_seq_len": 8192,
    "max_shift": 0.9,
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": 0.02,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False
    "base_image_seq_len": 256,
    "base_shift": 0.5,
    "invert_sigmas": False,
    "max_image_seq_len": 8192,
    "max_shift": 0.9,
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": 0.02,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}


class QwenImageModel(BaseModel):
    arch = "qwen_image"
    _qwen_image_keep_visual = False
    _qwen_pipeline = QwenImagePipeline

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype='bf16',
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device,
            model_config,
            dtype,
            custom_pipeline,
            noise_scheduler,
            **kwargs
            device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ['QwenImageTransformer2DModel']
        self.target_lora_modules = ["QwenImageTransformer2DModel"]

    # static method to get the noise scheduler
    @staticmethod
@@ -73,40 +79,58 @@ class QwenImageModel(BaseModel):
        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)

    def get_bucket_divisibility(self):
        return 16 * 2 # 16 for the VAE, 2 for patch size
        return 16 * 2  # 16 for the VAE, 2 for patch size

    def load_model(self):
        dtype = self.torch_dtype
        self.print_and_status_update("Loading Qwen Image model")
        model_path = self.model_config.name_or_path
        base_model_path = self.model_config.extras_name_or_path
        model_dtype = dtype

        transformer_path = model_path
        transformer_subfolder = 'transformer'
        if os.path.exists(transformer_path):
            transformer_subfolder = None
            transformer_path = os.path.join(transformer_path, 'transformer')
            # check if the path is a full checkpoint.
            te_folder_path = os.path.join(model_path, 'text_encoder')
            # if we have the te, this folder is a full checkpoint, use it as the base
            if os.path.exists(te_folder_path):
                base_model_path = model_path
        if base_model_path.endswith(".safetensors"):
            # use the repo for extras
            base_model_path = "Qwen/Qwen-Image"

        self.print_and_status_update("Loading transformer")
        transformer = QwenImageTransformer2DModel.from_pretrained(
            transformer_path,
            subfolder=transformer_subfolder,
            torch_dtype=dtype
        )

        if model_path.endswith(".safetensors"):
            # load the safetensors file
            transformer = QwenImageTransformer2DModel.from_single_file(
                model_path,
                config="Qwen/Qwen-Image",
                subfolder="transformer",
                torch_dtype=model_dtype,
            )
            transformer.to(model_dtype)

        else:
            transformer_path = model_path
            transformer_subfolder = "transformer"
            if os.path.exists(transformer_path):
                transformer_subfolder = None
                transformer_path = os.path.join(transformer_path, "transformer")
                # check if the path is a full checkpoint.
                te_folder_path = os.path.join(model_path, "text_encoder")
                # if we have the te, this folder is a full checkpoint, use it as the base
                if os.path.exists(te_folder_path):
                    base_model_path = model_path

            transformer = QwenImageTransformer2DModel.from_pretrained(
                transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype
            )

        if self.model_config.quantize:
            self.print_and_status_update("Quantizing Transformer")
            quantize_model(self, transformer)
            flush()

        if self.model_config.auto_memory:
            MemoryManager.attach(transformer, self.device_torch)

        if self.model_config.low_vram:
            self.print_and_status_update("Moving transformer to CPU")
            transformer.to('cpu')
            transformer.to("cpu")

        flush()
@@ -117,32 +141,35 @@ class QwenImageModel(BaseModel):
        text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            base_model_path, subfolder="text_encoder", torch_dtype=dtype
        )

        # remove the visual model as it is not needed for image generation
        self.processor = None
        if not self._qwen_image_keep_visual:
            text_encoder.model.visual = None

        if self.model_config.auto_memory:
            MemoryManager.attach(text_encoder, self.device_torch)

        text_encoder.to(self.device_torch, dtype=dtype)
        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing Text Encoder")
            quantize(text_encoder, weights=get_qtype(
                self.model_config.qtype_te))
            quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
            freeze(text_encoder)
            flush()

        self.print_and_status_update("Loading VAE")
        vae = AutoencoderKLQwenImage.from_pretrained(
            base_model_path, subfolder="vae", torch_dtype=dtype)
            base_model_path, subfolder="vae", torch_dtype=dtype
        )

        self.noise_scheduler = QwenImageModel.get_train_scheduler()

        self.print_and_status_update("Making pipe")

        kwargs = {}

        if self._qwen_image_keep_visual:
            try:
                self.processor = Qwen2VLProcessor.from_pretrained(
@@ -152,7 +179,7 @@ class QwenImageModel(BaseModel):
                self.processor = Qwen2VLProcessor.from_pretrained(
                    base_model_path, subfolder="processor"
                )
            kwargs['processor'] = self.processor
            kwargs["processor"] = self.processor

        pipe: QwenImagePipeline = self._qwen_pipeline(
            scheduler=self.noise_scheduler,
@@ -160,7 +187,7 @@ class QwenImageModel(BaseModel):
            tokenizer=tokenizer,
            vae=vae,
            transformer=None,
            **kwargs
            **kwargs,
        )
        # for quantization, it works best to do these after making the pipe
        pipe.text_encoder = text_encoder
@@ -198,7 +225,7 @@ class QwenImageModel(BaseModel):
            text_encoder=unwrap_model(self.text_encoder[0]),
            tokenizer=self.tokenizer[0],
            vae=unwrap_model(self.vae),
            transformer=unwrap_model(self.transformer)
            transformer=unwrap_model(self.transformer),
        )

        pipeline = pipeline.to(self.device_torch)
@@ -231,22 +258,27 @@ class QwenImageModel(BaseModel):

        # flush for low vram if we are doing that
        flush_between_steps = self.model_config.low_vram
        # Fix a bug in diffusers/torch

        # Fix a bug in diffusers/torch
        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if flush_between_steps:
                flush()
            latents = callback_kwargs["latents"]

            return {"latents": latents}

        sc = self.get_bucket_divisibility()
        gen_config.width = int(gen_config.width // sc * sc)
        gen_config.width = int(gen_config.width // sc * sc)
        gen_config.height = int(gen_config.height // sc * sc)
        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds,
            prompt_embeds_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=torch.int64),
            prompt_embeds_mask=conditional_embeds.attention_mask.to(
                self.device_torch, dtype=torch.int64
            ),
            negative_prompt_embeds=unconditional_embeds.text_embeds,
            negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=torch.int64),
            negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to(
                self.device_torch, dtype=torch.int64
            ),
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
@@ -254,7 +286,7 @@ class QwenImageModel(BaseModel):
            latents=gen_config.latents,
            generator=generator,
            callback_on_step_end=callback_on_step_end,
            **extra
            **extra,
        ).images[0]
        return img
@@ -263,28 +295,36 @@ class QwenImageModel(BaseModel):
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs
        **kwargs,
    ):
        self.model.to(self.device_torch)
        batch_size, num_channels_latents, height, width = latent_model_input.shape

        ps = self.transformer.config.patch_size

        # pack image tokens
        latent_model_input = latent_model_input.view(batch_size, num_channels_latents, height // ps, ps, width // ps, ps)
        latent_model_input = latent_model_input.view(
            batch_size, num_channels_latents, height // ps, ps, width // ps, ps
        )
        latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5)
        latent_model_input = latent_model_input.reshape(batch_size, (height // ps) * (width // ps), num_channels_latents * (ps * ps))
        latent_model_input = latent_model_input.reshape(
            batch_size, (height // ps) * (width // ps), num_channels_latents * (ps * ps)
        )

        # img_shapes passed to the model
        img_h2, img_w2 = height // ps, width // ps
        img_shapes = [[(1, img_h2, img_w2)]] * batch_size

        enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype)
        prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)
        prompt_embeds_mask = text_embeddings.attention_mask.to(
            self.device_torch, dtype=torch.int64
        )
        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()

        noise_pred = self.transformer(
            hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype).detach(),
            hidden_states=latent_model_input.to(
                self.device_torch, self.torch_dtype
            ).detach(),
            timestep=(timestep / 1000).detach(),
            guidance=None,
            encoder_hidden_states=enc_hs.detach(),
@@ -296,56 +336,55 @@ class QwenImageModel(BaseModel):
        )[0]

        # unpack
        noise_pred = noise_pred.view(batch_size, height // ps, width // ps, num_channels_latents, ps, ps)
        noise_pred = noise_pred.view(
            batch_size, height // ps, width // ps, num_channels_latents, ps, ps
        )
        noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5)
        noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width)
        return noise_pred
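Editor's note: the pack/unpack above only reshapes the latent grid into patch tokens and back. A standalone shape walk-through with illustrative sizes (not part of the commit):

# Illustrative only: batch_size=1, num_channels_latents=16, height=width=64, ps=2.
import torch

x = torch.randn(1, 16, 64, 64)
ps = 2
x = x.view(1, 16, 64 // ps, ps, 64 // ps, ps)            # (1, 16, 32, 2, 32, 2)
x = x.permute(0, 2, 4, 1, 3, 5)                          # (1, 32, 32, 16, 2, 2)
x = x.reshape(1, (64 // ps) * (64 // ps), 16 * ps * ps)  # 32*32 patch tokens
print(x.shape)  # torch.Size([1, 1024, 64]) -> 1024 tokens, 64 features each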
    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        if self.pipeline.text_encoder.device != self.device_torch:
            self.pipeline.text_encoder.to(self.device_torch)

        prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
            prompt,
            device=self.device_torch,
            num_images_per_prompt=1,
        )
        pe = PromptEmbeds(
            prompt_embeds
        )
        pe = PromptEmbeds(prompt_embeds)
        pe.attention_mask = prompt_embeds_mask
        return pe

    def get_model_has_grad(self):
        return False

    def get_te_has_grad(self):
        return False

    def save_model(self, output_path, meta, save_dtype):
        # only save the unet
        transformer: QwenImageTransformer2DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, 'transformer'),
            save_directory=os.path.join(output_path, "transformer"),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
        with open(meta_path, 'w') as f:
        meta_path = os.path.join(output_path, "aitk_meta.yaml")
        with open(meta_path, "w") as f:
            yaml.dump(meta, f)

    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get('noise')
        batch = kwargs.get('batch')
        noise = kwargs.get("noise")
        batch = kwargs.get("batch")
        return (noise - batch.latents).detach()

    def get_base_model_version(self):
        return "qwen_image"

    def get_transformer_block_names(self) -> Optional[List[str]]:
        return ['transformer_blocks']

        return ["transformer_blocks"]

    def convert_lora_weights_before_save(self, state_dict):
        new_sd = {}
        for key, value in state_dict.items():
@@ -359,20 +398,15 @@ class QwenImageModel(BaseModel):
            new_key = key.replace("diffusion_model.", "transformer.")
            new_sd[new_key] = value
        return new_sd

    def encode_images(
        self,
        image_list: List[torch.Tensor],
        device=None,
        dtype=None
    ):

    def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None):
        if device is None:
            device = self.vae_device_torch
        if dtype is None:
            dtype = self.vae_torch_dtype

        # Move to vae to device if on cpu
        if self.vae.device == 'cpu':
        if self.vae.device == "cpu":
            self.vae.to(device)
        self.vae.eval()
        self.vae.requires_grad_(False)
@@ -383,20 +417,19 @@ class QwenImageModel(BaseModel):

        images = images.unsqueeze(2)
        latents = self.vae.encode(images).latent_dist.sample()

        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )

        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
            1, self.vae.config.z_dim, 1, 1, 1
        ).to(latents.device, latents.dtype)

        latents = (latents - latents_mean) * latents_std
        latents = latents.to(device, dtype=dtype)

        latents = latents.squeeze(2)  # remove the frame count dimension

        return latents
        return latents
@@ -1759,7 +1759,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
            )

        # we cannot merge in if quantized
        if self.model_config.quantize:
        if self.model_config.quantize or self.model_config.auto_memory:
            # todo find a way around this
            self.network.can_merge_in = False
@@ -624,6 +624,15 @@ class ModelConfig:

        self.arch: ModelArch = kwargs.get("arch", None)

        # auto memory management, only for some models
        self.auto_memory = kwargs.get("auto_memory", False)
        if self.auto_memory and self.qtype == "qfloat8":
            print(f"Auto memory is not compatible with qfloat8, switching to float8 for model")
            self.qtype = "float8"
        if self.auto_memory and not self.qtype_te == "qfloat8":
            print(f"Auto memory is not compatible with qfloat8, switching to float8 for te")
            self.qtype_te = "float8"

        # can be used to load the extras like text encoder or vae from here
        # only setup for some models but will prevent having to download the te for
        # 20 different model variants
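Editor's note: a minimal sketch of how the new flag might be passed in, assuming ModelConfig keyword arguments as implied by the kwargs.get calls above; the values below are placeholders, not from the commit:

# Hypothetical example, not part of the commit.
from toolkit.config_modules import ModelConfig

model_config = ModelConfig(
    name_or_path="Qwen/Qwen-Image",  # placeholder model path
    arch="qwen_image",
    auto_memory=True,  # enables MemoryManager.attach(...) in load_model above
)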
@@ -650,6 +659,7 @@ class ModelConfig:

        if self.arch == "flex1":
            self.arch = "flux"


        # handle migrating to new model arch
        if self.arch is not None:
@@ -1,12 +1,92 @@
from typing import TYPE_CHECKING
import torch
from .manager_modules import LinearLayerMemoryManager, ConvLayerMemoryManager

if TYPE_CHECKING:
    from toolkit.models.base_model import BaseModel
LINEAR_MODULES = [
    "Linear",
    "LoRACompatibleLinear",
    "QLinear",
]
CONV_MODULES = [
    "Conv2d",
    "LoRACompatibleConv",
    "QConv2d",
]

UNMANAGED_MODULES = [
    "LayerNorm",
    "BatchNorm1d",
    "BatchNorm2d",
    "BatchNorm3d",
    "GroupNorm",
    "InstanceNorm1d",
    "InstanceNorm2d",
    "InstanceNorm3d",
    "Embedding",
    "EmbeddingBag",
    "RNNBase",
    "LSTM",
    "GRU",
    "RNN",
]

UNMANAGED_MODULES_INCLUDES = ["RotaryEmbedding", "Norm"]


class MemoryManager:
    def __init__(
        self,
        model: "BaseModel",
        module: torch.nn.Module,
        process_device: torch.device = torch.device("cpu"),
    ):
        self.model: "BaseModel" = model
        self.module: torch.nn.Module = module
        self.process_device: torch.device = process_device
        self.unmanaged_modules: list[torch.nn.Module] = []

    def memory_managed_to(self, *args, **kwargs):
        # first move all the unmanaged modules
        for module in self.unmanaged_modules:
            module.to(*args, **kwargs)
        # check for a dtype argument
        dtype = None
        if "dtype" in kwargs:
            dtype = kwargs["dtype"]
        elif len(args) > 0:
            for i, arg in enumerate(args):
                if isinstance(arg, torch.dtype):
                    dtype = arg
                    break
        if dtype is not None:
            return self.module._mm_to(dtype=dtype)
        return self.module

    @classmethod
    def attach(cls, module: torch.nn.Module, device: torch.device):
        if hasattr(module, "_memory_manager"):
            # already attached
            return

        module._memory_manager = cls(module, device)

        # override the to method to handle memory management
        module._mm_to = module.to
        module.to = module._memory_manager.memory_managed_to

        # attach to all modules
        for name, sub_module in module.named_modules():
            for child_name, child_module in sub_module.named_modules():
                if child_module.__class__.__name__ in LINEAR_MODULES:
                    # linear
                    LinearLayerMemoryManager.attach(
                        child_module, module._memory_manager
                    )
                elif child_module.__class__.__name__ in CONV_MODULES:
                    # conv
                    ConvLayerMemoryManager.attach(child_module, module._memory_manager)
                elif child_module.__class__.__name__ in UNMANAGED_MODULES or any(
                    inc in child_module.__class__.__name__
                    for inc in UNMANAGED_MODULES_INCLUDES
                ):
                    # unmanaged
                    module._memory_manager.unmanaged_modules.append(child_module)
                else:
                    continue
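Editor's note: a rough usage sketch of the new manager, mirroring the call pattern used in the Qwen Image model above (MemoryManager.attach(transformer, self.device_torch)). The commit is marked WIP, so treat this as illustrative rather than a stable API; the toy module below is hypothetical:

# Hypothetical usage, not part of the commit.
import torch
import torch.nn as nn
from toolkit.memory_management import MemoryManager

net = nn.Sequential(nn.Linear(256, 256), nn.LayerNorm(256), nn.Linear(256, 10))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MemoryManager.attach(net, device)
# Linear (and Conv) children get wrapped by the layer managers defined below and
# keep their weights in CPU memory; LayerNorm matches UNMANAGED_MODULES and keeps
# the ordinary .to() behavior via memory_managed_to.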
toolkit/memory_management/manager_modules.py (new file, 450 lines)
@@ -0,0 +1,450 @@
"""
|
||||
This code was heavily inspired by the work of Lodestone-Rock, pretty much all credit goes
|
||||
to them. The original code can be found here:
|
||||
https://github.com/lodestone-rock/RamTorch/blob/main/ramtorch/modules/linear.py
|
||||
|
||||
I simply modified it to work with a memory management model and with AI Toolkit's models
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .manager import MemoryManager
|
||||
|
||||
# --- Per-device global state registry ---
|
||||
_DEVICE_STATE = {}
|
||||
|
||||
|
||||
def _get_device_state(device: torch.device):
|
||||
"""Get or initialize per-device state."""
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
|
||||
# CPU path needs no CUDA state
|
||||
if device.type != "cuda":
|
||||
if device not in _DEVICE_STATE:
|
||||
_DEVICE_STATE[device] = {}
|
||||
return _DEVICE_STATE[device]
|
||||
|
||||
if device not in _DEVICE_STATE:
|
||||
with torch.cuda.device(device):
|
||||
_DEVICE_STATE[device] = {
|
||||
# streams & events
|
||||
"transfer_stream": torch.cuda.Stream(device=device),
|
||||
"transfer_grad_stream": torch.cuda.Stream(device=device),
|
||||
"transfer_forward_finished_event": torch.cuda.Event(),
|
||||
"compute_forward_start_event": torch.cuda.Event(),
|
||||
"transfer_backward_finished_event": torch.cuda.Event(),
|
||||
"transfer_weight_backward_finished_event": torch.cuda.Event(),
|
||||
"compute_backward_start_event": torch.cuda.Event(),
|
||||
"compute_backward_finished_event": torch.cuda.Event(),
|
||||
# ping-pong buffers
|
||||
"w_buffers": [None, None],
|
||||
"b_buffers": [None, None],
|
||||
"w_bwd_buffers": [None, None],
|
||||
# device-side staging for grads to be sent to CPU
|
||||
"w_grad_buffers": [None, None],
|
||||
"b_grad_buffers": [None, None],
|
||||
# clocks
|
||||
"forward_clk": 0,
|
||||
"backward_clk": 0,
|
||||
}
|
||||
return _DEVICE_STATE[device]
|
||||
|
||||
|
||||
def _ensure_cpu_pinned(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
||||
if t is None:
|
||||
return None
|
||||
if t.device.type != "cpu":
|
||||
t = t.to("cpu", copy=True)
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
t = t.pin_memory()
|
||||
except RuntimeError:
|
||||
pass
|
||||
return t
|
||||
|
||||
|
||||
def _move_params_to_cpu_and_pin(module: nn.Module):
|
||||
"""Force parameters to CPU (+pinned) so we can 'bounce' them per forward/backward."""
|
||||
with torch.no_grad():
|
||||
if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter):
|
||||
module.weight.data = _ensure_cpu_pinned(module.weight.data).detach()
|
||||
if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
|
||||
if module.bias is not None:
|
||||
module.bias.data = _ensure_cpu_pinned(module.bias.data).detach()
|
||||
|
||||
|
||||
# ==========================
|
||||
# Autograd functions (CUDA)
|
||||
# ==========================
|
||||
|
||||
|
||||
class _BouncingLinearFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, weight_cpu, bias_cpu, device: torch.device):
|
||||
if device.type != "cuda":
|
||||
out = F.linear(x.to("cpu"), weight_cpu, bias_cpu)
|
||||
ctx.save_for_backward(x.to("cpu"), weight_cpu, bias_cpu)
|
||||
ctx.device = torch.device("cpu")
|
||||
return out.to(x.device)
|
||||
|
||||
state = _get_device_state(device)
|
||||
ts = state["transfer_stream"]
|
||||
w_bufs, b_bufs = state["w_buffers"], state["b_buffers"]
|
||||
ev_tx_f = state["transfer_forward_finished_event"]
|
||||
ev_cu_s = state["compute_forward_start_event"]
|
||||
idx = state["forward_clk"]
|
||||
|
||||
with torch.cuda.stream(ts):
|
||||
ts.wait_event(ev_cu_s)
|
||||
w_bufs[idx] = weight_cpu.to(device, non_blocking=True)
|
||||
b_bufs[idx] = (
|
||||
bias_cpu.to(device, non_blocking=True) if bias_cpu is not None else None
|
||||
)
|
||||
state["forward_clk"] ^= 1
|
||||
ev_tx_f.record()
|
||||
|
||||
torch.cuda.current_stream().wait_event(ev_tx_f)
|
||||
ev_cu_s.record()
|
||||
out = F.linear(x, w_bufs[idx], b_bufs[idx])
|
||||
|
||||
ctx.save_for_backward(x, weight_cpu, bias_cpu)
|
||||
ctx.device = device
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
x, weight_cpu, bias_cpu = ctx.saved_tensors
|
||||
device = ctx.device
|
||||
|
||||
if device.type != "cuda":
|
||||
go_cpu = grad_out.to("cpu")
|
||||
x_cpu = x.to("cpu")
|
||||
grad_input = go_cpu @ weight_cpu
|
||||
grad_weight = go_cpu.flatten(0, -2).T @ x_cpu.flatten(0, -2)
|
||||
grad_bias = (
|
||||
go_cpu.sum(dim=tuple(range(go_cpu.ndim - 1)))
|
||||
if bias_cpu is not None
|
||||
else None
|
||||
)
|
||||
return grad_input.to(grad_out.device), grad_weight, grad_bias, None
|
||||
|
||||
state = _get_device_state(device)
|
||||
transfer_stream = state["transfer_stream"]
|
||||
transfer_grad_stream = state["transfer_grad_stream"]
|
||||
|
||||
w_bwd_buffers = state["w_bwd_buffers"]
|
||||
w_grad_buffers = state["w_grad_buffers"]
|
||||
b_grad_buffers = state["b_grad_buffers"]
|
||||
|
||||
ev_tx_b = state["transfer_backward_finished_event"]
|
||||
ev_tx_w_bwd_done = state["transfer_weight_backward_finished_event"]
|
||||
ev_cu_b_start = state["compute_backward_start_event"]
|
||||
ev_cu_b_finish = state["compute_backward_finished_event"]
|
||||
|
||||
idx = state["backward_clk"]
|
||||
|
||||
# Stage weights onto device (transfer stream), ping-pong to avoid races
|
||||
with torch.cuda.stream(transfer_stream):
|
||||
transfer_stream.wait_event(ev_cu_b_start)
|
||||
w_bwd_buffers[idx] = weight_cpu.to(device, non_blocking=True)
|
||||
state["backward_clk"] ^= 1
|
||||
ev_tx_b.record()
|
||||
|
||||
# Compute stream waits for weights to arrive, then start compute
|
||||
torch.cuda.current_stream().wait_event(ev_tx_b)
|
||||
ev_cu_b_start.record()
|
||||
|
||||
# 1) Compute grad_input using the freshly transferred weights
|
||||
grad_input = grad_out @ w_bwd_buffers[idx]
|
||||
|
||||
# 2) Ensure previous grad-to-CPU transfer that used this slot finished
|
||||
torch.cuda.current_stream().wait_event(ev_tx_w_bwd_done)
|
||||
|
||||
# 3) Compute weight/bias grads on GPU into staging buffers
|
||||
w_grad_buffers[idx] = grad_out.flatten(0, -2).T @ x.flatten(0, -2)
|
||||
if bias_cpu is not None:
|
||||
reduce_dims = tuple(range(grad_out.ndim - 1))
|
||||
b_grad_buffers[idx] = grad_out.sum(dim=reduce_dims)
|
||||
|
||||
# Mark end of GPU compute
|
||||
ev_cu_b_finish.record()
|
||||
|
||||
# 4) Launch non-blocking H2D->CPU transfers on a separate grad stream (full-duplex)
|
||||
with torch.cuda.stream(transfer_grad_stream):
|
||||
transfer_grad_stream.wait_event(ev_cu_b_finish)
|
||||
grad_weight = w_grad_buffers[idx].to("cpu", non_blocking=True)
|
||||
grad_bias = (
|
||||
b_grad_buffers[idx].to("cpu", non_blocking=True)
|
||||
if bias_cpu is not None
|
||||
else None
|
||||
)
|
||||
# signal that this slot's CPU transfer is complete (safe for next reuse)
|
||||
state["transfer_weight_backward_finished_event"].record()
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None
|
||||
|
||||
|
||||
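Editor's note: the forward path above is the core RamTorch idea: parameters live in pinned CPU memory, a dedicated transfer stream copies them into a ping-pong buffer while the compute stream works, and CUDA events order the two. A stripped-down, self-contained sketch of just that handoff (illustrative names, no ping-pong buffers or autograd):

# Illustrative sketch only (names are made up); shows the stream/event handoff
# used by _BouncingLinearFn.forward without the ping-pong buffers or autograd.
import torch
import torch.nn.functional as F

def bounced_linear(x, weight_cpu, bias_cpu, transfer_stream, ready_event):
    device = x.device
    with torch.cuda.stream(transfer_stream):
        # H2D copies run on the side stream; pinned CPU memory keeps them async
        w = weight_cpu.to(device, non_blocking=True)
        b = bias_cpu.to(device, non_blocking=True) if bias_cpu is not None else None
        ready_event.record()
    # the compute stream waits until the copies have landed
    torch.cuda.current_stream().wait_event(ready_event)
    return F.linear(x, w, b)

if torch.cuda.is_available():
    dev = torch.device("cuda")
    stream = torch.cuda.Stream(device=dev)
    event = torch.cuda.Event()
    w = torch.randn(512, 512).pin_memory()
    b = torch.randn(512).pin_memory()
    x = torch.randn(8, 512, device=dev)
    y = bounced_linear(x, w, b, stream, event)  # (8, 512) output on the GPU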
class _BouncingConv2dFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        weight_cpu,
        bias_cpu,
        device: torch.device,
        stride: Tuple[int, int],
        padding: Tuple[int, int],
        dilation: Tuple[int, int],
        groups: int,
    ):
        if device.type != "cuda":
            out = F.conv2d(
                x.to("cpu"), weight_cpu, bias_cpu, stride, padding, dilation, groups
            )
            ctx.save_for_backward(x.to("cpu"), weight_cpu, bias_cpu)
            ctx.meta = ("cpu", stride, padding, dilation, groups)
            return out.to(x.device)

        state = _get_device_state(device)
        ts = state["transfer_stream"]
        w_bufs, b_bufs = state["w_buffers"], state["b_buffers"]
        ev_tx_f = state["transfer_forward_finished_event"]
        ev_cu_s = state["compute_forward_start_event"]
        idx = state["forward_clk"]

        with torch.cuda.stream(ts):
            ts.wait_event(ev_cu_s)
            w_bufs[idx] = weight_cpu.to(device, non_blocking=True)
            b_bufs[idx] = (
                bias_cpu.to(device, non_blocking=True) if bias_cpu is not None else None
            )
            state["forward_clk"] ^= 1
            ev_tx_f.record()

        torch.cuda.current_stream().wait_event(ev_tx_f)
        ev_cu_s.record()
        out = F.conv2d(x, w_bufs[idx], b_bufs[idx], stride, padding, dilation, groups)

        ctx.save_for_backward(x, weight_cpu, bias_cpu)
        ctx.meta = (device, stride, padding, dilation, groups)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, weight_cpu, bias_cpu = ctx.saved_tensors
        meta = ctx.meta
        device, stride, padding, dilation, groups = meta

        if (
            isinstance(device, torch.device) and device.type != "cuda"
        ) or device == "cpu":
            # CPU grads
            go = grad_out.to("cpu")
            x_cpu = x.to("cpu")
            w_cpu = weight_cpu
            from torch.nn.grad import conv2d_input, conv2d_weight  # type: ignore

            grad_input = conv2d_input(
                x_cpu.shape,
                w_cpu,
                go,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            )
            grad_weight = conv2d_weight(
                x_cpu,
                w_cpu.shape,
                go,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            )
            grad_bias = go.sum(dim=(0, 2, 3)) if bias_cpu is not None else None
            return (
                grad_input.to(grad_out.device),
                grad_weight,
                grad_bias,
                None,
                None,
                None,
                None,
                None,
            )

        # CUDA path (full-duplex)
        state = _get_device_state(device)
        transfer_stream = state["transfer_stream"]
        transfer_grad_stream = state["transfer_grad_stream"]

        # device-side buffers
        w_bwd_buffers = state["w_bwd_buffers"]
        w_grad_buffers = state["w_grad_buffers"]
        b_grad_buffers = state["b_grad_buffers"]

        ev_tx_b = state["transfer_backward_finished_event"]
        ev_tx_w_bwd_done = state["transfer_weight_backward_finished_event"]
        ev_cu_b_start = state["compute_backward_start_event"]
        ev_cu_b_finish = state["compute_backward_finished_event"]

        idx = state["backward_clk"]

        # Stage weights for input-grad compute
        with torch.cuda.stream(transfer_stream):
            transfer_stream.wait_event(ev_cu_b_start)
            w_bwd_buffers[idx] = weight_cpu.to(device, non_blocking=True)
            state["backward_clk"] ^= 1
            ev_tx_b.record()

        torch.cuda.current_stream().wait_event(ev_tx_b)
        ev_cu_b_start.record()

        # grad wrt input on GPU with streamed weights
        from torch.nn.grad import conv2d_input, conv2d_weight  # type: ignore

        grad_input = conv2d_input(
            x.shape,
            w_bwd_buffers[idx],
            grad_out,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )

        # Ensure previous grad transfer that used this slot is done
        torch.cuda.current_stream().wait_event(ev_tx_w_bwd_done)

        # Compute heavy grads on GPU into staging buffers
        w_grad_buffers[idx] = conv2d_weight(
            x,
            weight_cpu.shape,
            grad_out,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        if bias_cpu is not None:
            b_grad_buffers[idx] = grad_out.sum(dim=(0, 2, 3))

        # Mark end of GPU math
        ev_cu_b_finish.record()

        # Launch CPU copies on the dedicated grad stream (overlaps with next H2D)
        with torch.cuda.stream(transfer_grad_stream):
            transfer_grad_stream.wait_event(ev_cu_b_finish)
            grad_weight = w_grad_buffers[idx].to("cpu", non_blocking=True)
            grad_bias = (
                b_grad_buffers[idx].to("cpu", non_blocking=True)
                if bias_cpu is not None
                else None
            )
            state["transfer_weight_backward_finished_event"].record()

        return grad_input, grad_weight, grad_bias, None, None, None, None, None
class BaseLayerMemoryManager:
    def __init__(
        self,
        module: nn.Module,
        manager: "MemoryManager",
    ):
        self.module: nn.Module = module
        self.manager: "MemoryManager" = manager

    @classmethod
    def attach(cls, module: nn.Module, manager: "MemoryManager"):
        if hasattr(module, "_layer_memory_manager"):
            return
        module._layer_memory_manager = cls(module, manager)

        # mark parameters as memory managed
        for param in module.parameters(recurse=False):
            param._is_memory_managed = True


class LinearLayerMemoryManager(BaseLayerMemoryManager):
    def __init__(
        self,
        module: nn.Module,
        manager: "MemoryManager",
    ):
        super().__init__(module, manager)

        # 1) Move params to CPU + pin memory for fast H2D
        _move_params_to_cpu_and_pin(self.module)

        # 2) Hijack forward
        self._original_forward = getattr(self.module, "forward")

        def _mm_forward(x, *args, **kwargs):
            # ensure we only use expected signature (Linear: x)
            if args or kwargs:
                # fall back to original if a custom signature is used
                return self._original_forward(x, *args, **kwargs)

            weight_cpu = self.module.weight
            bias_cpu = getattr(self.module, "bias", None)
            device = self.manager.process_device

            # NOTE: do NOT move params to device here; autograd fn streams & bounces them
            return _BouncingLinearFn.apply(x, weight_cpu, bias_cpu, device)

        self.module.forward = _mm_forward


class ConvLayerMemoryManager(BaseLayerMemoryManager):
    def __init__(
        self,
        module: nn.Module,
        manager: "MemoryManager",
    ):
        super().__init__(module, manager)

        # 1) Move params to CPU + pin memory for fast H2D
        _move_params_to_cpu_and_pin(self.module)

        # Cache static conv attributes from the module
        stride = (
            self.module.stride
            if isinstance(self.module.stride, tuple)
            else (self.module.stride, self.module.stride)
        )
        padding = (
            self.module.padding
            if isinstance(self.module.padding, tuple)
            else (self.module.padding, self.module.padding)
        )
        dilation = (
            self.module.dilation
            if isinstance(self.module.dilation, tuple)
            else (self.module.dilation, self.module.dilation)
        )
        groups = self.module.groups

        # 2) Hijack forward
        self._original_forward = getattr(self.module, "forward")

        def _mm_forward(x, *args, **kwargs):
            # Support the typical Conv2d(x) call; if user passes uncommon extras, fallback.
            if args or kwargs:
                return self._original_forward(x, *args, **kwargs)

            weight_cpu = self.module.weight
            bias_cpu = getattr(self.module, "bias", None)
            device = self.manager.process_device

            return _BouncingConv2dFn.apply(
                x, weight_cpu, bias_cpu, device, stride, padding, dilation, groups
            )

        self.module.forward = _mm_forward
@@ -41,7 +41,6 @@ from torchvision.transforms import functional as TF
from toolkit.accelerator import get_accelerator, unwrap_model
from typing import TYPE_CHECKING
from toolkit.print import print_acc
from toolkit.memory_management import MemoryManager

if TYPE_CHECKING:
    from toolkit.lora_special import LoRASpecialNetwork
@@ -186,8 +185,6 @@ class BaseModel:
        self.has_multiple_control_images = False
        # do not resize control images
        self.use_raw_control_images = False

        self.memory_manager = MemoryManager(self)

    # properties for old arch for backwards compatibility
    @property
@@ -70,7 +70,6 @@ from typing import TYPE_CHECKING
from toolkit.print import print_acc
from diffusers import FluxFillPipeline
from transformers import AutoModel, AutoTokenizer, Gemma2Model, Qwen2Model, LlamaModel
from toolkit.memory_management import MemoryManager

if TYPE_CHECKING:
    from toolkit.lora_special import LoRASpecialNetwork
@@ -225,8 +224,6 @@ class StableDiffusion:
        # do not resize control images
        self.use_raw_control_images = False

        self.memory_manager = MemoryManager(self)

    # properties for old arch for backwards compatibility
    @property
    def is_xl(self):
@@ -301,14 +301,14 @@ def quantize_model(
        f" - quantizing {len(all_blocks)} transformer blocks"
    )
    for block in tqdm(all_blocks):
        block.to(base_model.device_torch, dtype=base_model.torch_dtype)
        block.to(base_model.device_torch, dtype=base_model.torch_dtype, non_blocking=True)
        quantize(block, weights=quantization_type)
        freeze(block)
        block.to("cpu")
        block.to("cpu", non_blocking=True)

    # todo, on extras find a universal way to quantize them on device and move them back to their original
    # device without having to move the transformer blocks to the device first
    base_model.print_and_status_update(" - quantizing extras")
    model_to_quantize.to(base_model.device_torch, dtype=base_model.torch_dtype)
    # model_to_quantize.to(base_model.device_torch, dtype=base_model.torch_dtype)
    quantize(model_to_quantize, weights=quantization_type)
    freeze(model_to_quantize)