Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-30 11:11:37 +00:00)

Merge pull request #264 from ostris/cogview4

Added basics for CogView4. Broken as hell though. Don't use.
.gitmodules (vendored) | 4 ++++

@@ -1,12 +1,16 @@
 [submodule "repositories/sd-scripts"]
 	path = repositories/sd-scripts
 	url = https://github.com/kohya-ss/sd-scripts.git
+	commit = b78c0e2a69e52ce6c79abc6c8c82d1a9cabcf05c
 [submodule "repositories/leco"]
 	path = repositories/leco
 	url = https://github.com/p1atdev/LECO
+	commit = 9294adf40218e917df4516737afb13f069a6789d
 [submodule "repositories/batch_annotator"]
 	path = repositories/batch_annotator
 	url = https://github.com/ostris/batch-annotator
+	commit = 420e142f6ad3cc14b3ea0500affc2c6c7e7544bf
 [submodule "repositories/ipadapter"]
 	path = repositories/ipadapter
 	url = https://github.com/tencent-ailab/IP-Adapter.git
+	commit = 5a18b1f3660acaf8bee8250692d6fb3548a19b14
@@ -380,9 +380,19 @@ class SDTrainer(BaseSDTrainProcess):
             elif self.sd.prediction_type == 'v_prediction':
                 # v-parameterization training
                 target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)

+            elif hasattr(self.sd, 'get_loss_target'):
+                target = self.sd.get_loss_target(
+                    noise=noise,
+                    batch=batch,
+                    timesteps=timesteps,
+                ).detach()

             elif self.sd.is_flow_matching:
+                # forward ODE
                 target = (noise - batch.latents).detach()
+                # reverse ODE
+                # target = (batch.latents - noise).detach()
             else:
                 target = noise
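The new elif gives a model class a hook to define its own regression target before the generic flow-matching and epsilon fallbacks run. A minimal sketch of the contract this hunk implies (the method body shown matches the CogView4 implementation later in this diff; the class name is illustrative):

    class MyModel(BaseModel):
        def get_loss_target(self, noise=None, batch=None, timesteps=None, **kwargs):
            # flow-matching velocity target: noise minus clean latents
            return (noise - batch.latents).detach()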
@@ -68,6 +68,8 @@ import transformers
 import diffusers
 import hashlib

+from toolkit.util.get_model import get_model_class
+
 def flush():
     torch.cuda.empty_cache()
     gc.collect()
@@ -666,7 +668,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # # prepare all the models stuff for accelerator (hopefully we dont miss any)
         self.sd.vae = self.accelerator.prepare(self.sd.vae)
         if self.sd.unet is not None:
-            self.sd.unet_unwrapped = self.sd.unet
            self.sd.unet = self.accelerator.prepare(self.sd.unet)
            # todo always tdo it?
            self.modules_being_trained.append(self.sd.unet)
@@ -1103,11 +1104,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
            if timestep_type is None:
                timestep_type = self.train_config.timestep_type

+           patch_size = 1
+           if self.sd.is_flux:
+               # flux is a patch size of 1, but latents are divided by 2, so we need to double it
+               patch_size = 2
+           elif hasattr(self.sd.unet.config, 'patch_size'):
+               patch_size = self.sd.unet.config.patch_size
+
            self.sd.noise_scheduler.set_train_timesteps(
                num_train_timesteps,
                device=self.device_torch,
                timestep_type=timestep_type,
-               latents=latents
+               latents=latents,
+               patch_size=patch_size,
            )
        else:
            self.sd.noise_scheduler.set_timesteps(
@@ -1401,21 +1410,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
                model_config_to_load.name_or_path = latest_save_path
                self.load_training_state_from_metadata(latest_save_path)

-       # get the noise scheduler
-       arch = 'sd'
-       if self.model_config.is_pixart:
-           arch = 'pixart'
-       if self.model_config.is_flux:
-           arch = 'flux'
-       if self.model_config.is_lumina2:
-           arch = 'lumina2'
-       sampler = get_sampler(
-           self.train_config.noise_scheduler,
-           {
-               "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
-           },
-           arch=arch,
-       )
+       ModelClass = get_model_class(self.model_config)
+       # if the model class has get_train_scheduler static method
+       if hasattr(ModelClass, 'get_train_scheduler'):
+           sampler = ModelClass.get_train_scheduler()
+       else:
+           # get the noise scheduler
+           arch = 'sd'
+           if self.model_config.is_pixart:
+               arch = 'pixart'
+           if self.model_config.is_flux:
+               arch = 'flux'
+           if self.model_config.is_lumina2:
+               arch = 'lumina2'
+           sampler = get_sampler(
+               self.train_config.noise_scheduler,
+               {
+                   "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
+               },
+               arch=arch,
+           )

        if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
            previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner')
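This dispatch lets a model class ship its own training scheduler instead of the arch-string lookup. A minimal sketch of the opt-in shape (CogView4 later in this diff implements exactly this; the class name here is illustrative):

    class MyTransformerModel(BaseModel):
        @staticmethod
        def get_train_scheduler():
            # returning any ready-to-use scheduler instance short-circuits
            # the legacy get_sampler() arch-string path entirely
            return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)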
@@ -1423,7 +1437,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                model_config_to_load.refiner_name_or_path = previous_refiner_save
                self.load_training_state_from_metadata(previous_refiner_save)

-       self.sd = StableDiffusion(
+       self.sd = ModelClass(
            device=self.device,
            model_config=model_config_to_load,
            dtype=self.train_config.dtype,
@@ -1559,6 +1573,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
            # if is_lycoris:
            #     preset = PRESET['full']
            #     NetworkClass.apply_preset(preset)

+           if hasattr(self.sd, 'target_lora_modules'):
+               network_kwargs['target_lin_modules'] = self.sd.target_lora_modules
+
            self.network = NetworkClass(
                text_encoder=text_encoder,
@@ -1587,6 +1604,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                network_config=self.network_config,
                network_type=self.network_config.type,
                transformer_only=self.network_config.transformer_only,
+               is_transformer=self.sd.is_transformer,
                **network_kwargs
            )
@@ -1,8 +1,8 @@
 torch==2.5.1
 torchvision==0.20.1
 safetensors
-git+https://github.com/huggingface/diffusers@28f48f4051e80082cbe97f2d62b365dbb01040ec
-transformers
+git+https://github.com/huggingface/diffusers@24c062aaa19f5626d03d058daf8afffa2dfd49f7
+transformers==4.49.0
 lycoris-lora==1.8.3
 flatten_json
 pyyaml
@@ -29,7 +29,7 @@ def paramiter_count(model):
     return int(paramiter_count)


-def calculate_metrics(vae, images, max_imgs=-1):
+def calculate_metrics(vae, images, max_imgs=-1, save_output=False):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     vae = vae.to(device)
     lpips_model = lpips.LPIPS(net='alex').to(device)
@@ -44,6 +44,9 @@ def calculate_metrics(vae, images, max_imgs=-1):
     # ])
     # needs values between -1 and 1
     to_tensor = ToTensor()

+    # remove _reconstructed.png files
+    images = [img for img in images if not img.endswith("_reconstructed.png")]
+
     if max_imgs > 0 and len(images) > max_imgs:
         images = images[:max_imgs]
@@ -82,6 +85,15 @@ def calculate_metrics(vae, images, max_imgs=-1):
     avg_rfid = 0
     avg_psnr = sum(psnr_scores) / len(psnr_scores)
     avg_lpips = sum(lpips_scores) / len(lpips_scores)

+    if save_output:
+        filename_no_ext = os.path.splitext(os.path.basename(img_path))[0]
+        folder = os.path.dirname(img_path)
+        save_path = os.path.join(folder, filename_no_ext + "_reconstructed.png")
+        reconstructed = (reconstructed + 1) / 2
+        reconstructed = reconstructed.clamp(0, 1)
+        reconstructed = transforms.ToPILImage()(reconstructed[0].cpu())
+        reconstructed.save(save_path)
+
     return avg_rfid, avg_psnr, avg_lpips
@@ -91,18 +103,23 @@ def main():
     parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model")
     parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images")
     parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.")
+    # boolean store true
+    parser.add_argument("--save_output", action="store_true", help="Save the output images")
     args = parser.parse_args()

     if os.path.isfile(args.vae_path):
         vae = AutoencoderKL.from_single_file(args.vae_path)
     else:
-        vae = AutoencoderKL.from_pretrained(args.vae_path)
+        try:
+            vae = AutoencoderKL.from_pretrained(args.vae_path)
+        except:
+            vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
     vae.eval()
     vae = vae.to(device)
     print(f"Model has {paramiter_count(vae)} parameters")
     images = load_images(args.image_folder)

-    avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs)
+    avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs, args.save_output)

     # print(f"Average rFID: {avg_rfid}")
     print(f"Average PSNR: {avg_psnr}")
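A quick way to exercise the new flag programmatically, inside the same script's namespace (the VAE repo id and image folder below are placeholders, not taken from the diff):

    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # placeholder id
    images = load_images("/path/to/images")
    # save_output=True writes <name>_reconstructed.png next to each input;
    # the new filter above then skips those files on subsequent runs
    avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, max_imgs=100, save_output=True)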
@@ -432,6 +432,9 @@ class TrainConfig:
        self.force_consistent_noise = kwargs.get('force_consistent_noise', False)


+ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']
+
+
 class ModelConfig:
     def __init__(self, **kwargs):
         self.name_or_path: str = kwargs.get('name_or_path', None)
@@ -509,6 +512,36 @@ class ModelConfig:
        self.split_model_other_module_param_count_scale = kwargs.get("split_model_other_module_param_count_scale", 0.3)

        self.te_name_or_path = kwargs.get("te_name_or_path", None)

+       self.arch: ModelArch = kwargs.get("arch", None)
+
+       # handle migrating to new model arch
+       if self.arch is None:
+           if kwargs.get('is_v2', False):
+               self.arch = 'sd2'
+           elif kwargs.get('is_v3', False):
+               self.arch = 'sd3'
+           elif kwargs.get('is_xl', False):
+               self.arch = 'sdxl'
+           elif kwargs.get('is_pixart', False):
+               self.arch = 'pixart'
+           elif kwargs.get('is_pixart_sigma', False):
+               self.arch = 'pixart_sigma'
+           elif kwargs.get('is_auraflow', False):
+               self.arch = 'auraflow'
+           elif kwargs.get('is_flux', False):
+               self.arch = 'flux'
+           elif kwargs.get('is_flex2', False):
+               self.arch = 'flex2'
+           elif kwargs.get('is_lumina2', False):
+               self.arch = 'lumina2'
+           elif kwargs.get('is_vega', False):
+               self.arch = 'vega'
+           elif kwargs.get('is_ssd', False):
+               self.arch = 'ssd'
+           else:
+               self.arch = 'sd1'


 class EMAConfig:
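One consequence worth spelling out: legacy boolean flags in existing configs keep working through the migration chain above. A minimal sketch (the name_or_path values are placeholders):

    from toolkit.config_modules import ModelConfig

    old_style = ModelConfig(name_or_path="some/model", is_xl=True)
    assert old_style.arch == 'sdxl'   # migrated from the legacy flag

    new_style = ModelConfig(name_or_path="THUDM/CogView4-6B", arch="cogview4")
    # note: 'cogview4' is what toolkit.util.get_model dispatches on, even
    # though it is not yet listed in the ModelArch Literal above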
@@ -178,6 +178,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            transformer_only: bool = False,
            peft_format: bool = False,
            is_assistant_adapter: bool = False,
+           is_transformer: bool = False,
            **kwargs
    ) -> None:
        """
@@ -237,9 +238,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        self.network_config: NetworkConfig = kwargs.get("network_config", None)

        self.peft_format = peft_format
+       self.is_transformer = is_transformer
+
        # always do peft for flux only for now
-       if self.is_flux or self.is_v3 or self.is_lumina2:
+       if self.is_flux or self.is_v3 or self.is_lumina2 or is_transformer:
            # don't do peft format for lokr
            if self.network_type.lower() != "lokr":
                self.peft_format = True
@@ -282,7 +285,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        unet_prefix = self.LORA_PREFIX_UNET
        if self.peft_format:
            unet_prefix = self.PEFT_PREFIX_UNET
-       if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2:
+       if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2 or self.is_transformer:
            unet_prefix = f"lora_transformer"
            if self.peft_format:
                unet_prefix = "transformer"
@@ -341,6 +344,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                if self.transformer_only and self.is_v3 and is_unet:
                    if "transformer_blocks" not in lora_name:
                        skip = True

+               # handle custom models
+               if self.transformer_only and is_unet and hasattr(root_module, 'transformer_blocks'):
+                   if "transformer_blocks" not in lora_name:
+                       skip = True
+
                if (is_linear or is_conv2d) and not skip:
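The practical effect of the new flag: any BaseModel subclass that sets is_transformer = True (as CogView4 does) gets PEFT-style naming without a per-architecture check. A sketch of the resulting key shape (illustrative of the PEFT layout, not copied from a checkpoint in this PR):

    # with is_transformer=True and a non-LoKr network type, peft_format is
    # forced on and the prefix becomes "transformer", so saved keys look like:
    #   transformer.transformer_blocks.0.attn1.to_q.lora_A.weight
    # rather than the legacy "lora_unet_..." naming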
toolkit/models/base_model.py (new file, 1426 lines)
File diff suppressed because it is too large.
toolkit/models/cogview4.py (new file, 466 lines)

@@ -0,0 +1,466 @@
# DONT USE THIS!. IT DOES NOT WORK YET!
# Will revisit this when they release more info on how it was trained.

import weakref
from diffusers import CogView4Pipeline
import torch
import yaml

from toolkit.basic import flush
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds

import os
import copy
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
import torch
import diffusers
from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline
from optimum.quanto import freeze, qfloat8, QTensor, qint4
from toolkit.util.quantize import quantize
from transformers import GlmModel, AutoTokenizer
from diffusers import FlowMatchEulerDiscreteScheduler
from typing import TYPE_CHECKING
from toolkit.accelerator import unwrap_model
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler

if TYPE_CHECKING:
    from toolkit.lora_special import LoRASpecialNetwork

# remove this after a bug is fixed in diffusers code. This is a workaround.


class FakeModel:
    def __init__(self, model):
        self.model_ref = weakref.ref(model)
        pass

    @property
    def device(self):
        return self.model_ref().device


scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": 0.25,
    "invert_sigmas": False,
    "max_image_seq_len": 4096,
    "max_shift": 0.75,
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "time_shift_type": "linear",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False
}


class CogView4(BaseModel):
    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype='bf16',
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs
    ):
        super().__init__(device, model_config, dtype,
                         custom_pipeline, noise_scheduler, **kwargs)
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ['CogView4Transformer2DModel']

        # cache for holding noise
        self.effective_noise = None

    # static method to get the scheduler
    @staticmethod
    def get_train_scheduler():
        scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
        return scheduler

    def load_model(self):
        dtype = self.torch_dtype
        base_model_path = "THUDM/CogView4-6B"
        model_path = self.model_config.name_or_path

        self.print_and_status_update("Loading CogView4 model")
        # base_model_path = "black-forest-labs/FLUX.1-schnell"
        base_model_path = self.model_config.name_or_path_original
        subfolder = 'transformer'
        transformer_path = model_path
        if os.path.exists(transformer_path):
            subfolder = None
            transformer_path = os.path.join(transformer_path, 'transformer')
            # check if the path is a full checkpoint.
            te_folder_path = os.path.join(model_path, 'text_encoder')
            # if we have the te, this folder is a full checkpoint, use it as the base
            if os.path.exists(te_folder_path):
                base_model_path = model_path

        self.print_and_status_update("Loading GlmModel")
        tokenizer = AutoTokenizer.from_pretrained(
            base_model_path, subfolder="tokenizer", torch_dtype=dtype)
        text_encoder = GlmModel.from_pretrained(
            base_model_path, subfolder="text_encoder", torch_dtype=dtype)

        text_encoder.to(self.device_torch, dtype=dtype)
        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing GlmModel")
            quantize(text_encoder, weights=qfloat8)
            freeze(text_encoder)
            flush()

        # hack to fix diffusers bug workaround
        text_encoder.model = FakeModel(text_encoder)

        self.print_and_status_update("Loading transformer")
        transformer = CogView4Transformer2DModel.from_pretrained(
            transformer_path,
            subfolder=subfolder,
            torch_dtype=dtype,
        )

        if self.model_config.split_model_over_gpus:
            raise ValueError(
                "Splitting model over gpus is not supported for CogViewModels models")

        transformer.to(self.quantize_device, dtype=dtype)
        flush()

        if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
            raise ValueError(
                "Assistant LoRA is not supported for CogViewModels models currently")

        if self.model_config.lora_path is not None:
            raise ValueError(
                "Loading LoRA is not supported for CogViewModels models currently")

        flush()

        if self.model_config.quantize:
            quantization_args = self.model_config.quantize_kwargs
            if 'exclude' not in quantization_args:
                quantization_args['exclude'] = []
            if 'include' not in quantization_args:
                quantization_args['include'] = []

            # Be more specific with the include pattern to exactly match transformer blocks
            quantization_args['include'] += ["transformer_blocks.*"]

            # Exclude all LayerNorm layers within transformer blocks
            quantization_args['exclude'] += [
                "transformer_blocks.*.norm1",
                "transformer_blocks.*.norm2",
                "transformer_blocks.*.norm2_context",
                "transformer_blocks.*.attn1.norm_q",
                "transformer_blocks.*.attn1.norm_k"
            ]

            # patch the state dict method
            patch_dequantization_on_save(transformer)
            quantization_type = qfloat8
            self.print_and_status_update("Quantizing transformer")
            quantize(transformer, weights=quantization_type, **quantization_args)
            freeze(transformer)
            transformer.to(self.device_torch)
        else:
            transformer.to(self.device_torch, dtype=dtype)

        flush()

        scheduler = CogView4.get_train_scheduler()
        self.print_and_status_update("Loading VAE")
        vae = AutoencoderKL.from_pretrained(
            base_model_path, subfolder="vae", torch_dtype=dtype)
        flush()

        self.print_and_status_update("Making pipe")
        pipe: CogView4Pipeline = CogView4Pipeline(
            scheduler=scheduler,
            text_encoder=None,
            tokenizer=tokenizer,
            vae=vae,
            transformer=None,
        )
        pipe.text_encoder = text_encoder
        pipe.transformer = transformer

        self.print_and_status_update("Preparing Model")

        text_encoder = pipe.text_encoder
        tokenizer = pipe.tokenizer

        pipe.transformer = pipe.transformer.to(self.device_torch)

        flush()
        text_encoder.to(self.device_torch)
        text_encoder.requires_grad_(False)
        text_encoder.eval()
        pipe.transformer = pipe.transformer.to(self.device_torch)
        flush()
        self.pipeline = pipe
        self.model = transformer
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer

    def get_generation_pipeline(self):
        scheduler = CogView4.get_train_scheduler()
        pipeline = CogView4Pipeline(
            vae=self.vae,
            transformer=self.unet,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            scheduler=scheduler,
        )
        return pipeline

    def generate_single_image(
        self,
        pipeline: CogView4Pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype),
            negative_prompt_embeds=unconditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype),
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            **extra
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs
    ):
        # target_size = (height, width)
        target_size = latent_model_input.shape[-2:]
        # multiply by 8
        target_size = (target_size[0] * 8, target_size[1] * 8)
        crops_coords_top_left = torch.tensor(
            [(0, 0)], dtype=self.torch_dtype, device=self.device_torch)

        original_size = torch.tensor(
            [target_size], dtype=self.torch_dtype, device=self.device_torch)
        target_size = original_size.clone()
        noise_pred_cond = self.model(
            hidden_states=latent_model_input,
            encoder_hidden_states=text_embeddings.text_embeds,
            timestep=timestep,
            original_size=original_size,
            target_size=target_size,
            crop_coords=crops_coords_top_left,
            return_dict=False,
        )[0]
        return noise_pred_cond

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        prompt_embeds, _ = self.pipeline.encode_prompt(
            prompt,
            do_classifier_free_guidance=False,
            device=self.device_torch,
            dtype=self.torch_dtype,
        )
        return PromptEmbeds(prompt_embeds)

    def get_model_has_grad(self):
        return self.model.proj_out.weight.requires_grad

    def get_te_has_grad(self):
        return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad

    def save_model(self, output_path, meta, save_dtype):
        # only save the unet
        transformer: CogView4Transformer2DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, 'transformer'),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
        with open(meta_path, 'w') as f:
            yaml.dump(meta, f)

    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get('noise')
        effective_noise = self.effective_noise
        batch = kwargs.get('batch')
        if batch is None:
            raise ValueError("Batch is not provided")
        if noise is None:
            raise ValueError("Noise is not provided")
        # return batch.latents
        # return (batch.latents - noise).detach()
        return (noise - batch.latents).detach()
        # return (batch.latents).detach()
        # return (effective_noise - batch.latents).detach()

    def _get_low_res_latents(self, latents):
        # todo prevent needing to do this and grab the tensor another way.
        with torch.no_grad():
            # Decode latents to image space
            images = self.decode_latents(
                latents, device=latents.device, dtype=latents.dtype)

            # Downsample by a factor of 2 using bilinear interpolation
            B, C, H, W = images.shape
            low_res_images = torch.nn.functional.interpolate(
                images,
                size=(H // 2, W // 2),
                mode="bilinear",
                align_corners=False
            )

            # Upsample back to original resolution to match expected VAE input dimensions
            upsampled_low_res_images = torch.nn.functional.interpolate(
                low_res_images,
                size=(H, W),
                mode="bilinear",
                align_corners=False
            )

            # Encode the low-resolution images back to latent space
            low_res_latents = self.encode_images(
                upsampled_low_res_images, device=latents.device, dtype=latents.dtype)
        return low_res_latents

    # def add_noise(
    #     self,
    #     original_samples: torch.FloatTensor,
    #     noise: torch.FloatTensor,
    #     timesteps: torch.IntTensor,
    #     **kwargs,
    # ) -> torch.FloatTensor:
    #     relay_start_point = 500

    #     # Store original samples for loss calculation
    #     self.original_samples = original_samples

    #     # Prepare chunks for batch processing
    #     original_samples_chunks = torch.chunk(
    #         original_samples, original_samples.shape[0], dim=0)
    #     noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
    #     timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)

    #     # Get the low res latents only if needed
    #     low_res_latents_chunks = None

    #     # Handle case where timesteps is a single value for all samples
    #     if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
    #         timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)

    #     noisy_latents_chunks = []
    #     effective_noise_chunks = []  # Store the effective noise for each sample

    #     for idx in range(original_samples.shape[0]):
    #         t = timesteps_chunks[idx]
    #         t_01 = (t / 1000).to(original_samples_chunks[idx].device)

    #         # Flowmatching interpolation between original and noise
    #         if t > relay_start_point:
    #             # Standard flowmatching - direct linear interpolation
    #             noisy_latents = (1 - t_01) * original_samples_chunks[idx] + t_01 * noise_chunks[idx]
    #             effective_noise_chunks.append(noise_chunks[idx])  # Effective noise is just the noise
    #         else:
    #             # Relay flowmatching case - only compute low_res_latents if needed
    #             if low_res_latents_chunks is None:
    #                 low_res_latents = self._get_low_res_latents(original_samples)
    #                 low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)

    #             # Calculate the relay ratio (0 to 1)
    #             t_ratio = t.float() / relay_start_point
    #             t_ratio = torch.clamp(t_ratio, 0.0, 1.0)

    #             # First blend between original and low-res based on t_ratio
    #             z0_t = (1 - t_ratio) * original_samples_chunks[idx] + t_ratio * low_res_latents_chunks[idx]

    #             added_lor_res_noise = z0_t - original_samples_chunks[idx]

    #             # Then apply flowmatching interpolation between this blended state and noise
    #             noisy_latents = (1 - t_01) * z0_t + t_01 * noise_chunks[idx]

    #             # For prediction target, we need to store the effective "source"
    #             effective_noise_chunks.append(noise_chunks[idx] + added_lor_res_noise)

    #         noisy_latents_chunks.append(noisy_latents)

    #     noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
    #     self.effective_noise = torch.cat(effective_noise_chunks, dim=0)  # Store for loss calculation

    #     return noisy_latents

    # def add_noise(
    #     self,
    #     original_samples: torch.FloatTensor,
    #     noise: torch.FloatTensor,
    #     timesteps: torch.IntTensor,
    #     **kwargs,
    # ) -> torch.FloatTensor:
    #     relay_start_point = 500

    #     # Store original samples for loss calculation
    #     self.original_samples = original_samples

    #     # Prepare chunks for batch processing
    #     original_samples_chunks = torch.chunk(
    #         original_samples, original_samples.shape[0], dim=0)
    #     noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
    #     timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)

    #     # Get the low res latents only if needed
    #     low_res_latents = self._get_low_res_latents(original_samples)
    #     low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)

    #     # Handle case where timesteps is a single value for all samples
    #     if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
    #         timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)

    #     noisy_latents_chunks = []
    #     effective_noise_chunks = []  # Store the effective noise for each sample

    #     for idx in range(original_samples.shape[0]):
    #         t = timesteps_chunks[idx]
    #         t_01 = (t / 1000).to(original_samples_chunks[idx].device)

    #         lrln = low_res_latents_chunks[idx] - original_samples_chunks[idx]
    #         # lrln = lrln * (1 - t_01)

    #         # make the noise an interpolation between noise and low_res_latents with
    #         # being noise at t_01=1 and low_res_latents at t_01=0
    #         new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln
    #         # new_noise = noise_chunks[idx] + lrln
    #         # new_noise = noise_chunks[idx] + lrln

    #         # Then apply flowmatching interpolation between this blended state and noise
    #         noisy_latents = (1 - t_01) * original_samples + t_01 * new_noise

    #         # For prediction target, we need to store the effective "source"
    #         effective_noise_chunks.append(new_noise)

    #         noisy_latents_chunks.append(noisy_latents)

    #     noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
    #     self.effective_noise = torch.cat(effective_noise_chunks, dim=0)  # Store for loss calculation

    #     return noisy_latents
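The two commented-out add_noise variants sketch a "relay" flow-matching experiment: below a cutoff timestep, the clean latents are blended toward a re-encoded low-resolution copy before interpolating toward noise. Reading the second variant in that notation (x = original latents, x_lr = low-res latents from _get_low_res_latents, \epsilon = noise, t \in [0, 1]), it computes

    \tilde\epsilon_t = t\,\epsilon + (1 - t)\,(x_{lr} - x), \qquad z_t = (1 - t)\,x + t\,\tilde\epsilon_t

with \tilde\epsilon_t cached as effective_noise so the loss target could become \tilde\epsilon_t - x. This summary is an interpretation of the commented code; the active get_loss_target still returns the plain \epsilon - x target.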
toolkit/models/wan21.py (new file, 82 lines)

@@ -0,0 +1,82 @@
# WIP, coming soon ish
import torch
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds
from toolkit.paths import REPOS_ROOT
import sys
import os

import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial

import torch
import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm


class Wan21(BaseModel):
    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype='bf16',
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs
    ):
        super().__init__(device, model_config, dtype,
                         custom_pipeline, noise_scheduler, **kwargs)
        self.is_flow_matching = True
        raise NotImplementedError("Wan21 is not implemented yet")

    # these must be implemented in child classes
    def load_model(self):
        pass

    def get_generation_pipeline(self):
        # override this in child classes
        raise NotImplementedError(
            "get_generation_pipeline must be implemented in child classes")

    def generate_single_image(
        self,
        pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        # override this in child classes
        raise NotImplementedError(
            "generate_single_image must be implemented in child classes")

    def get_noise_prediction(
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs
    ):
        raise NotImplementedError(
            "get_noise_prediction must be implemented in child classes")

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        raise NotImplementedError(
            "get_prompt_embeds must be implemented in child classes")

    def get_model_has_grad(self):
        raise NotImplementedError(
            "get_model_has_grad must be implemented in child classes")

    def get_te_has_grad(self):
        raise NotImplementedError(
            "get_te_has_grad must be implemented in child classes")
@@ -44,7 +44,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
        hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())

        # flatten second half to max
-       hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max()
+       hbsmntw_weighing[num_timesteps //
+                        2:] = hbsmntw_weighing[num_timesteps // 2:].max()

        # Create linear timesteps from 1000 to 0
        timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
@@ -56,7 +57,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):

    def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
        # Get the indices of the timesteps
-       step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
+       step_indices = [(self.timesteps == t).nonzero().item()
+                       for t in timesteps]

        # Get the weights for the timesteps
        if v2:
@@ -70,7 +72,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
        sigmas = self.sigmas.to(device=device, dtype=dtype)
        schedule_timesteps = self.timesteps.to(device)
        timesteps = timesteps.to(device)
-       step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+       step_indices = [(schedule_timesteps == t).nonzero().item()
+                       for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
@@ -84,27 +87,24 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
-       ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
-       ## Add noise according to flow matching.
-       ## zt = (1 - texp) * x + texp * z1
-
-       # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
-       # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
-
-       # timestep needs to be in [0, 1], we store them in [0, 1000]
-       # noisy_sample = (1 - timestep) * latent + timestep * noise
        t_01 = (timesteps / 1000).to(original_samples.device)
-       noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
-       # n_dim = original_samples.ndim
-       # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
-       # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
+       # forward ODE
+       noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise
+       # reverse ODE
+       # noisy_model_input = (1 - t_01) * noise + t_01 * original_samples
        return noisy_model_input

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        return sample

-   def set_train_timesteps(self, num_timesteps, device, timestep_type='linear', latents=None):
+   def set_train_timesteps(
+       self,
+       num_timesteps,
+       device,
+       timestep_type='linear',
+       latents=None,
+       patch_size=1
+   ):
        self.timestep_type = timestep_type
        if timestep_type == 'linear':
            timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
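For reference, the forward ODE kept here is the standard flow-matching (rectified-flow) interpolation. With x the clean latents, \epsilon the noise, and t = timesteps / 1000 \in [0, 1]:

    z_t = (1 - t)\,x + t\,\epsilon, \qquad \frac{dz_t}{dt} = \epsilon - x

which is why SDTrainer's flow-matching branch, and CogView4's get_loss_target, both regress onto (noise - batch.latents); the commented-out "reverse ODE" simply swaps the roles of x and \epsilon.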
@@ -124,42 +124,67 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
            self.timesteps = timesteps.to(device=device)

            return timesteps
-       elif timestep_type == 'flux_shift' or timestep_type == 'lumina2_shift':
+       elif timestep_type in ['flux_shift', 'lumina2_shift', 'shift']:
            # matches inference dynamic shifting
            timesteps = np.linspace(
-               self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_timesteps
+               self._sigma_to_t(self.sigma_max), self._sigma_to_t(
+                   self.sigma_min), num_timesteps
            )

            sigmas = timesteps / self.config.num_train_timesteps

-           if latents is None:
-               raise ValueError('latents is None')
-
-           h = latents.shape[2] // 2  # Divide by ph
-           w = latents.shape[3] // 2  # Divide by pw
-           image_seq_len = h * w
-
-           # todo need to know the mu for the shift
-           mu = calculate_shift(
-               image_seq_len,
-               self.config.get("base_image_seq_len", 256),
-               self.config.get("max_image_seq_len", 4096),
-               self.config.get("base_shift", 0.5),
-               self.config.get("max_shift", 1.16),
-           )
-           sigmas = self.time_shift(mu, 1.0, sigmas)
-
-           sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+           if self.config.use_dynamic_shifting:
+               if latents is None:
+                   raise ValueError('latents is None')
+
+               # for flux we double up the patch size before sending her to simulate the latent reduction
+               h = latents.shape[2]
+               w = latents.shape[3]
+               image_seq_len = h * w // (patch_size**2)
+
+               mu = calculate_shift(
+                   image_seq_len,
+                   self.config.get("base_image_seq_len", 256),
+                   self.config.get("max_image_seq_len", 4096),
+                   self.config.get("base_shift", 0.5),
+                   self.config.get("max_shift", 1.16),
+               )
+               sigmas = self.time_shift(mu, 1.0, sigmas)
+           else:
+               sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
+
+           if self.config.shift_terminal:
+               sigmas = self.stretch_shift_to_terminal(sigmas)
+
+           if self.config.use_karras_sigmas:
+               sigmas = self._convert_to_karras(
+                   in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
+           elif self.config.use_exponential_sigmas:
+               sigmas = self._convert_to_exponential(
+                   in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
+           elif self.config.use_beta_sigmas:
+               sigmas = self._convert_to_beta(
+                   in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
+
+           sigmas = torch.from_numpy(sigmas).to(
+               dtype=torch.float32, device=device)
            timesteps = sigmas * self.config.num_train_timesteps

-           sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+           if self.config.invert_sigmas:
+               sigmas = 1.0 - sigmas
+               timesteps = sigmas * self.config.num_train_timesteps
+               sigmas = torch.cat(
+                   [sigmas, torch.ones(1, device=sigmas.device)])
+           else:
+               sigmas = torch.cat(
+                   [sigmas, torch.zeros(1, device=sigmas.device)])

            self.timesteps = timesteps.to(device=device)
            self.sigmas = sigmas

            self.timesteps = timesteps.to(device=device)
            return timesteps

        elif timestep_type == 'lognorm_blend':
            # disgtribute timestepd to the center/early and blend in linear
            alpha = 0.75
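To make the patch_size plumbing concrete: a 1024x1024 Flux image gives 128x128 latents, and with the simulated patch size of 2 the sequence length is 128 * 128 / 2^2 = 4096, exactly the default max_image_seq_len, so mu lands at max_shift. A sketch of the calculate_shift helper this code calls (this mirrors the diffusers implementation; treat the exact signature as an assumption):

    def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
                        base_shift=0.5, max_shift=1.16):
        # mu interpolates linearly between base_shift and max_shift as the
        # image token count grows from base_seq_len to max_seq_len
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        b = base_shift - m * base_seq_len
        return image_seq_len * m + b

    mu = calculate_shift(128 * 128 // 2 ** 2)  # -> 1.16 for a 1024x1024 Flux image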
@@ -173,7 +198,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
            t1 = ((1 - t1/t1.max()) * 1000)

            # add half of linear
-           t2 = torch.linspace(1000, 0, int(num_timesteps * (1 - alpha)), device=device)
+           t2 = torch.linspace(1000, 0, int(
+               num_timesteps * (1 - alpha)), device=device)
            timesteps = torch.cat((t1, t2))

            # Sort the timesteps in descending order
@@ -29,7 +29,7 @@ from toolkit.ip_adapter import IPAdapter
 from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
     convert_vae_state_dict, load_vae
 from toolkit import train_tools
-from toolkit.config_modules import ModelConfig, GenerateImageConfig
+from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.models.decorator import Decorator
 from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
@@ -64,7 +64,8 @@ from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from huggingface_hub import hf_hub_download
 from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance

-from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
+from optimum.quanto import freeze, qfloat8, QTensor, qint4
+from toolkit.util.quantize import quantize
 from toolkit.accelerator import get_accelerator, unwrap_model
 from typing import TYPE_CHECKING
 from toolkit.print import print_acc
@@ -160,7 +161,6 @@ class StableDiffusion:
        self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline']
        self.vae: Union[None, 'AutoencoderKL']
        self.unet: Union[None, 'UNet2DConditionModel']
-       self.unet_unwrapped: Union[None, 'UNet2DConditionModel']
        self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
        self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
        self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
@@ -177,16 +177,17 @@ class StableDiffusion:
        self.network = None
        self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None
        self.decorator: Union[Decorator, None] = None
-       self.is_xl = model_config.is_xl
-       self.is_v2 = model_config.is_v2
-       self.is_ssd = model_config.is_ssd
-       self.is_v3 = model_config.is_v3
-       self.is_vega = model_config.is_vega
-       self.is_pixart = model_config.is_pixart
-       self.is_auraflow = model_config.is_auraflow
-       self.is_flux = model_config.is_flux
-       self.is_flex2 = model_config.is_flex2
-       self.is_lumina2 = model_config.is_lumina2
+       self.arch: ModelArch = model_config.arch
+       # self.is_xl = model_config.is_xl
+       # self.is_v2 = model_config.is_v2
+       # self.is_ssd = model_config.is_ssd
+       # self.is_v3 = model_config.is_v3
+       # self.is_vega = model_config.is_vega
+       # self.is_pixart = model_config.is_pixart
+       # self.is_auraflow = model_config.is_auraflow
+       # self.is_flux = model_config.is_flux
+       # self.is_flex2 = model_config.is_flex2
+       # self.is_lumina2 = model_config.is_lumina2

        self.use_text_encoder_1 = model_config.use_text_encoder_1
        self.use_text_encoder_2 = model_config.use_text_encoder_2
@@ -204,6 +205,53 @@ class StableDiffusion:
        self.invert_assistant_lora = False
        self._after_sample_img_hooks = []
        self._status_update_hooks = []
+       # todo update this based on the model
+       self.is_transformer = False
+
+   # properties for old arch for backwards compatibility
+   @property
+   def is_xl(self):
+       return self.arch == 'sdxl'
+
+   @property
+   def is_v2(self):
+       return self.arch == 'sd2'
+
+   @property
+   def is_ssd(self):
+       return self.arch == 'ssd'
+
+   @property
+   def is_v3(self):
+       return self.arch == 'sd3'
+
+   @property
+   def is_vega(self):
+       return self.arch == 'vega'
+
+   @property
+   def is_pixart(self):
+       return self.arch == 'pixart'
+
+   @property
+   def is_auraflow(self):
+       return self.arch == 'auraflow'
+
+   @property
+   def is_flux(self):
+       return self.arch == 'flux'
+
+   @property
+   def is_flex2(self):
+       return self.arch == 'flex2'
+
+   @property
+   def is_lumina2(self):
+       return self.arch == 'lumina2'
+
+   @property
+   def unet_unwrapped(self):
+       return unwrap_model(self.unet)
+
    def load_model(self):
        if self.is_loaded:
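The net effect is that a single arch string now drives every legacy flag, and unet_unwrapped is computed on demand instead of being stored. A quick sketch of the equivalence (the property accesses are real; the instance setup is illustrative):

    # sd = StableDiffusion(device, model_config=ModelConfig(arch='flux'), ...)
    # sd.arch           -> 'flux'
    # sd.is_flux        -> True   (derived property, no stored flag)
    # sd.is_xl          -> False
    # sd.unet_unwrapped -> unwrap_model(sd.unet), computed on access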
@@ -935,7 +983,6 @@ class StableDiffusion:
        if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2:
            # pixart and sd3 dont use a unet
            self.unet = pipe.transformer
-           self.unet_unwrapped = pipe.transformer
        else:
            self.unet: 'UNet2DConditionModel' = pipe.unet
        self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
@@ -1734,7 +1781,8 @@ class StableDiffusion:
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
-       timesteps: torch.IntTensor
+       timesteps: torch.IntTensor,
+       **kwargs,
    ) -> torch.FloatTensor:
        original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
        noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
toolkit/util/get_model.py (new file, 12 lines)

@@ -0,0 +1,12 @@
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.config_modules import ModelConfig


def get_model_class(config: ModelConfig):
    if config.arch == "wan21":
        from toolkit.models.wan21 import Wan21
        return Wan21
    elif config.arch == "cogview4":
        from toolkit.models.cogview4 import CogView4
        return CogView4
    else:
        return StableDiffusion
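A minimal usage sketch tying the new pieces together (the model path is a placeholder; CogView4 is the class added in this PR):

    from toolkit.config_modules import ModelConfig
    from toolkit.util.get_model import get_model_class

    config = ModelConfig(arch="cogview4", name_or_path="THUDM/CogView4-6B")
    ModelClass = get_model_class(config)   # -> CogView4
    # BaseSDTrainProcess now instantiates self.sd = ModelClass(...) instead of
    # hard-coding StableDiffusion, and uses ModelClass.get_train_scheduler()
    # when the class provides one.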
toolkit/util/quantize.py (new file, 55 lines)

@@ -0,0 +1,55 @@
from fnmatch import fnmatch
from typing import Any, Dict, List, Optional, Union
import torch

from optimum.quanto.quantize import _quantize_submodule
from optimum.quanto.tensor import Optimizer, qtype

# the quantize function in quanto had a bug where it was using exclude instead of include


def quantize(
    model: torch.nn.Module,
    weights: Optional[Union[str, qtype]] = None,
    activations: Optional[Union[str, qtype]] = None,
    optimizer: Optional[Optimizer] = None,
    include: Optional[Union[str, List[str]]] = None,
    exclude: Optional[Union[str, List[str]]] = None,
):
    """Quantize the specified model submodules

    Recursively quantize the submodules of the specified parent model.

    Only modules that have quantized counterparts will be quantized.

    If include patterns are specified, the submodule name must match one of them.

    If exclude patterns are specified, the submodule must not match one of them.

    Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See
    https://docs.python.org/3/library/fnmatch.html for more details.

    Note: quantization happens in-place and modifies the original model and its descendants.

    Args:
        model (`torch.nn.Module`): the model whose submodules will be quantized.
        weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization.
        activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.
        include (`Optional[Union[str, List[str]]]`):
            Patterns constituting the allowlist. If provided, module names must match at
            least one pattern from the allowlist.
        exclude (`Optional[Union[str, List[str]]]`):
            Patterns constituting the denylist. If provided, module names must not match
            any patterns from the denylist.
    """
    if include is not None:
        include = [include] if isinstance(include, str) else include
    if exclude is not None:
        exclude = [exclude] if isinstance(exclude, str) else exclude
    for name, m in model.named_modules():
        if include is not None and not any(fnmatch(name, pattern) for pattern in include):
            continue
        if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
            continue
        _quantize_submodule(model, name, m, weights=weights,
                            activations=activations, optimizer=optimizer)
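Usage mirrors how cogview4.py calls it: quantize the transformer blocks while keeping their norm layers in full precision (the patterns below are copied from this diff; the transformer variable is illustrative):

    from optimum.quanto import freeze, qfloat8
    from toolkit.util.quantize import quantize

    quantize(
        transformer,                      # any torch.nn.Module
        weights=qfloat8,
        include=["transformer_blocks.*"],
        exclude=["transformer_blocks.*.norm1", "transformer_blocks.*.norm2"],
    )
    freeze(transformer)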