From 77b10d884d1c2ee0de79335ba817df8c40e21884 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 12 Aug 2025 08:21:36 -0600
Subject: [PATCH] Add support for training with an accuracy recovery adapter with qwen image

---
 .../diffusion_models/qwen_image/qwen_image.py |  23 +-
 jobs/process/BaseSDTrainProcess.py            |   8 +
 toolkit/config_modules.py                     |  10 +
 toolkit/models/base_model.py                  |   4 +-
 toolkit/util/quantize.py                      | 221 +++++++++++++++++-
 ui/src/app/jobs/new/SimpleJob.tsx             |  55 ++++-
 ui/src/app/jobs/new/options.ts                |   5 +
 version.py                                    |   2 +-
 8 files changed, 292 insertions(+), 36 deletions(-)

diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py
index b257930a..1dd3da9d 100644
--- a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py
+++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py
@@ -10,10 +10,9 @@ from toolkit.models.base_model import BaseModel
 from toolkit.basic import flush
 from toolkit.prompt_utils import PromptEmbeds
 from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
-from toolkit.dequantize import patch_dequantization_on_save
 from toolkit.accelerator import get_accelerator, unwrap_model
 from optimum.quanto import freeze, QTensor
-from toolkit.util.quantize import quantize, get_qtype
+from toolkit.util.quantize import quantize, get_qtype, quantize_model
 import torch.nn.functional as F

 from diffusers import QwenImagePipeline, QwenImageTransformer2DModel, AutoencoderKLQwenImage
@@ -99,23 +98,9 @@ class QwenImageModel(BaseModel):
         )

         if self.model_config.quantize:
-            # patch the state dict method
-            patch_dequantization_on_save(transformer)
-            # move and quantize only certain pieces at a time.
-            quantization_type = get_qtype(self.model_config.qtype)
-            all_blocks = list(transformer.transformer_blocks)
-            self.print_and_status_update(" - quantizing transformer blocks")
-            for block in tqdm(all_blocks):
-                block.to(self.device_torch, dtype=dtype)
-                quantize(block, weights=quantization_type)
-                freeze(block)
-                block.to('cpu')
-                # flush()
-
-            self.print_and_status_update(" - quantizing extras")
-            transformer.to(self.device_torch, dtype=dtype)
-            quantize(transformer, weights=quantization_type)
-            freeze(transformer)
+            self.print_and_status_update("Quantizing Transformer")
+            quantize_model(self, transformer)
+            flush()

         if self.model_config.low_vram:
             self.print_and_status_update("Moving transformer to CPU")
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 156ed766..e5dd8c8c 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1533,6 +1533,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # run base sd process run
         self.sd.load_model()

+        # compile the model if needed
+        if self.model_config.compile:
+            try:
+                self.sd.unet = torch.compile(self.sd.unet, dynamic=True, fullgraph=True, mode='max-autotune')
+            except Exception as e:
+                print_acc(f"Failed to compile model: {e}")
+                print_acc("Continuing without compilation")
+
         self.sd.add_after_sample_image_hook(self.sample_step_hook)

         dtype = get_torch_dtype(self.train_config.dtype)
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index e5a5633d..8d9e734b 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -601,6 +601,16 @@ class ModelConfig:
         # 20 different model variants
         self.extras_name_or_path = kwargs.get("extras_name_or_path", self.name_or_path)

+        # path to an accuracy recovery adapter, either local or remote
+        self.accuracy_recovery_adapter = kwargs.get("accuracy_recovery_adapter", None)
+
+        # parse ARA from qtype
+        if self.qtype is not None and "|" in self.qtype:
+            self.qtype, self.accuracy_recovery_adapter = self.qtype.split('|')
+
+        # compile the model with torch compile
+        self.compile = kwargs.get("compile", False)
+
         # kwargs to pass to the model
         self.model_kwargs = kwargs.get("model_kwargs", {})

diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py
index fc0cb764..412e04d4 100644
--- a/toolkit/models/base_model.py
+++ b/toolkit/models/base_model.py
@@ -168,8 +168,10 @@ class BaseModel:
         self._after_sample_img_hooks = []
         self._status_update_hooks = []
         self.is_transformer = False
-
+
         self.sample_prompts_cache = None
+
+        self.accuracy_recovery_adapter: Union[None, 'LoRASpecialNetwork'] = None

     # properties for old arch for backwards compatibility
     @property
diff --git a/toolkit/util/quantize.py b/toolkit/util/quantize.py
index b7bded55..641c8ae9 100644
--- a/toolkit/util/quantize.py
+++ b/toolkit/util/quantize.py
@@ -1,19 +1,36 @@
 from fnmatch import fnmatch
-from typing import Any, Dict, List, Optional, Union
+from typing import List, Optional, Union, TYPE_CHECKING
 import torch
-from dataclasses import dataclass
 from optimum.quanto.quantize import _quantize_submodule
 from optimum.quanto.tensor import Optimizer, qtype, qtypes
 from torchao.quantization.quant_api import (
     quantize_ as torchao_quantize_,
     Float8WeightOnlyConfig,
-    UIntXWeightOnlyConfig
+    UIntXWeightOnlyConfig,
 )
+from optimum.quanto import freeze
+from tqdm import tqdm
+from safetensors.torch import load_file
+from huggingface_hub import hf_hub_download
+
+from toolkit.print import print_acc
+import os
+
+if TYPE_CHECKING:
+    from toolkit.models.base_model import BaseModel


 # the quantize function in quanto had a bug where it was using exclude instead of include
-Q_MODULES = ['QLinear', 'QConv2d', 'QEmbedding', 'QBatchNorm2d', 'QLayerNorm', 'QConvTranspose2d', 'QEmbeddingBag']
+Q_MODULES = [
+    "QLinear",
+    "QConv2d",
+    "QEmbedding",
+    "QBatchNorm2d",
+    "QLayerNorm",
+    "QConvTranspose2d",
+    "QEmbeddingBag",
+]


 torchao_qtypes = {
     # "int4": Int4WeightOnlyConfig(),
@@ -27,11 +44,13 @@ torchao_qtypes = {
     "float8": Float8WeightOnlyConfig(),
 }

+
 class aotype:
     def __init__(self, name: str):
         self.name = name
         self.config = torchao_qtypes[name]

+
 def get_qtype(qtype: Union[str, qtype]) -> qtype:
     if qtype in torchao_qtypes:
         return aotype(qtype)
@@ -40,6 +59,7 @@ def get_qtype(qtype: Union[str, qtype]) -> qtype:
     else:
         return qtype

+
 def quantize(
     model: torch.nn.Module,
     weights: Optional[Union[str, qtype, aotype]] = None,
@@ -79,7 +99,9 @@
     if exclude is not None:
         exclude = [exclude] if isinstance(exclude, str) else exclude
     for name, m in model.named_modules():
-        if include is not None and not any(fnmatch(name, pattern) for pattern in include):
+        if include is not None and not any(
+            fnmatch(name, pattern) for pattern in include
+        ):
             continue
         if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
             continue
@@ -91,8 +113,191 @@ def quantize(
             if isinstance(weights, aotype):
                 torchao_quantize_(m, weights.config)
             else:
-                _quantize_submodule(model, name, m, weights=weights,
-                                    activations=activations, optimizer=optimizer)
+                _quantize_submodule(
+                    model,
+                    name,
+                    m,
+                    weights=weights,
+                    activations=activations,
+                    optimizer=optimizer,
+                )
         except Exception as e:
             print(f"Failed to quantize {name}: {e}")
-            # raise e
\ No newline at end of file
+            # raise e
+
+
+def quantize_model(
"BaseModel", + model_to_quantize: torch.nn.Module, +): + from toolkit.dequantize import patch_dequantization_on_save + + if not hasattr(base_model, "get_transformer_block_names"): + raise ValueError( + "The model to quantize must have a method `get_transformer_block_names`." + ) + + # patch the state dict method + patch_dequantization_on_save(model_to_quantize) + + if base_model.model_config.accuracy_recovery_adapter is not None: + from toolkit.config_modules import NetworkConfig + from toolkit.lora_special import LoRASpecialNetwork + + # we need to load and quantize with an accuracy recovery adapter + # todo handle hf repos + load_lora_path = base_model.model_config.accuracy_recovery_adapter + + if not os.path.exists(load_lora_path): + # not local file, grab from the hub + + path_split = load_lora_path.split("/") + if len(path_split) > 3: + raise ValueError( + "The accuracy recovery adapter path must be a local path or for a hf repo, 'username/repo_name/filename.safetensors'." + ) + repo_id = f"{path_split[0]}/{path_split[1]}" + print_acc(f"Grabbing lora from the hub: {load_lora_path}") + new_lora_path = hf_hub_download( + repo_id, + filename=path_split[-1], + ) + # replace the path + load_lora_path = new_lora_path + + # build the lora config based on the lora weights + lora_state_dict = load_file(load_lora_path) + + if hasattr(base_model, "convert_lora_weights_before_load"): + lora_state_dict = base_model.convert_lora_weights_before_load(lora_state_dict) + + network_config = { + "type": "lora", + "network_kwargs": {"only_if_contains": []}, + "transformer_only": False, + } + first_key = list(lora_state_dict.keys())[0] + first_weight = lora_state_dict[first_key] + # if it starts with lycoris and includes lokr + if first_key.startswith("lycoris") and any( + "lokr" in key for key in lora_state_dict.keys() + ): + network_config["type"] = "lokr" + + network_kwargs = {} + + # find firse loraA weight + if network_config["type"] == "lora": + linear_dim = None + for key, value in lora_state_dict.items(): + if "lora_A" in key: + linear_dim = int(value.shape[0]) + break + linear_alpha = linear_dim + network_config["linear"] = linear_dim + network_config["linear_alpha"] = linear_alpha + + # we build the keys to match every key + only_if_contains = [] + for key in lora_state_dict.keys(): + contains_key = key.split(".lora_")[0] + if contains_key not in only_if_contains: + only_if_contains.append(contains_key) + + network_kwargs["only_if_contains"] = only_if_contains + elif network_config["type"] == "lokr": + # find the factor + largest_factor = 0 + for key, value in lora_state_dict.items(): + if "lokr_w1" in key: + factor = int(value.shape[0]) + if factor > largest_factor: + largest_factor = factor + network_config["lokr_full_rank"] = True + network_config["lokr_factor"] = largest_factor + + only_if_contains = [] + for key in lora_state_dict.keys(): + if "lokr_w1" in key: + contains_key = key.split(".lokr_w1")[0] + contains_key = contains_key.replace("lycoris_", "") + if contains_key not in only_if_contains: + only_if_contains.append(contains_key) + network_kwargs["only_if_contains"] = only_if_contains + + if hasattr(base_model, 'target_lora_modules'): + network_kwargs['target_lin_modules'] = base_model.target_lora_modules + + # todo auto grab these + # get dim and scale + network_config = NetworkConfig(**network_config) + + network = LoRASpecialNetwork( + text_encoder=None, + unet=model_to_quantize, + lora_dim=network_config.linear, + multiplier=1.0, + alpha=network_config.linear_alpha, + # 
+            # conv_lora_dim=self.network_config.conv,
+            # conv_alpha=self.network_config.conv_alpha,
+            train_unet=True,
+            train_text_encoder=False,
+            network_config=network_config,
+            network_type=network_config.type,
+            transformer_only=network_config.transformer_only,
+            is_transformer=base_model.is_transformer,
+            base_model=base_model,
+            **network_kwargs
+        )
+        network.apply_to(
+            None, model_to_quantize, apply_text_encoder=False, apply_unet=True
+        )
+        network.force_to(base_model.device_torch, dtype=base_model.torch_dtype)
+        network._update_torch_multiplier()
+        network.load_weights(lora_state_dict)
+        network.eval()
+        network.is_active = True
+        network.can_merge_in = False
+        base_model.accuracy_recovery_adapter = network
+
+        # quantize it
+        quantization_type = get_qtype(base_model.model_config.qtype)
+        for lora_module in tqdm(network.unet_loras, desc="Attaching quantization"):
+            # the lora has already hijacked the original module
+            orig_module = lora_module.org_module[0]
+            orig_module.to(base_model.torch_dtype)
+            # make the params not require gradients
+            for param in orig_module.parameters():
+                param.requires_grad = False
+            quantize(orig_module, weights=quantization_type)
+            freeze(orig_module)
+            if base_model.model_config.low_vram:
+                # move it back to cpu
+                orig_module.to("cpu")
+
+    else:
+        # quantize model the original way without an accuracy recovery adapter
+        # move and quantize only certain pieces at a time.
+        quantization_type = get_qtype(base_model.model_config.qtype)
+        # all_blocks = list(model_to_quantize.transformer_blocks)
+        all_blocks: List[torch.nn.Module] = []
+        transformer_block_names = base_model.get_transformer_block_names()
+        for name in transformer_block_names:
+            block = getattr(model_to_quantize, name, None)
+            if block is not None:
+                all_blocks.append(block)
+        base_model.print_and_status_update(
+            f" - quantizing {len(all_blocks)} transformer blocks"
+        )
+        for block in tqdm(all_blocks):
+            block.to(base_model.device_torch, dtype=base_model.torch_dtype)
+            quantize(block, weights=quantization_type)
+            freeze(block)
+            block.to("cpu")
+
+        # todo, on extras find a universal way to quantize them on device and move them back to their original
+        # device without having to move the transformer blocks to the device first
+        base_model.print_and_status_update(" - quantizing extras")
+        model_to_quantize.to(base_model.device_torch, dtype=base_model.torch_dtype)
+        quantize(model_to_quantize, weights=quantization_type)
+        freeze(model_to_quantize)
diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx
index 7ba08e68..b2e31294 100644
--- a/ui/src/app/jobs/new/SimpleJob.tsx
+++ b/ui/src/app/jobs/new/SimpleJob.tsx
@@ -2,7 +2,7 @@
 import { useMemo } from 'react';
 import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options';
 import { defaultDatasetConfig } from './jobConfig';
-import { JobConfig } from '@/types';
+import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
 import { objectCopy } from '@/utils/basic';
 import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
 import Card from '@/components/Card';
@@ -46,6 +46,47 @@ export default function SimpleJob({
     topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
   }

+  const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => {
+    const hasARA = modelArch?.accuracyRecoveryAdapters && Object.keys(modelArch.accuracyRecoveryAdapters).length > 0;
+    if (!hasARA) {
+      return quantizationOptions;
+    }
+    let newQuantizationOptions = [
+      {
+        label: 'Standard',
+        options: [quantizationOptions[0], quantizationOptions[1]],
+      },
+    ];
+
+    // add ARAs if they exist for the model
+    let ARAs: SelectOption[] = [];
+    if (modelArch.accuracyRecoveryAdapters) {
+      for (const [label, value] of Object.entries(modelArch.accuracyRecoveryAdapters)) {
+        ARAs.push({ value, label });
+      }
+    }
+    if (ARAs.length > 0) {
+      newQuantizationOptions.push({
+        label: 'Accuracy Recovery Adapters',
+        options: ARAs,
+      });
+    }
+
+    let additionalQuantizationOptions: SelectOption[] = [];
+    // add the quantization options if they are not already included
+    for (let i = 2; i < quantizationOptions.length; i++) {
+      const option = quantizationOptions[i];
+      additionalQuantizationOptions.push(option);
+    }
+    if (additionalQuantizationOptions.length > 0) {
+      newQuantizationOptions.push({
+        label: 'Additional Quantization Options',
+        options: additionalQuantizationOptions,
+      });
+    }
+    return newQuantizationOptions;
+  }, [modelArch]);
+
   return (
     <>
       <div className={topBarClass}>
@@ -180,7 +221,7 @@ export default function SimpleJob({
                  }
                  setJobConfig(value, 'config.process[0].model.qtype');
                }}
-                options={quantizationOptions}
+                options={transformerQuantizationOptions}
              />
              <Checkbox
-                onChange={(value) => {
-                  setJobConfig(value, 'config.process[0].train.unload_text_encoder')
+                onChange={value => {
+                  setJobConfig(value, 'config.process[0].train.unload_text_encoder');
                  if (value) {
                    setJobConfig(false, 'config.process[0].train.cache_text_embeddings');
                  }
@@ -416,10 +457,10 @@ export default function SimpleJob({
                label="Cache Text Embeddings"
                checked={jobConfig.config.process[0].train.cache_text_embeddings || false}
                docKey={'train.cache_text_embeddings'}
-                onChange={(value) => {
-                  setJobConfig(value, 'config.process[0].train.cache_text_embeddings')
+                onChange={value => {
+                  setJobConfig(value, 'config.process[0].train.cache_text_embeddings');
                  if (value) {
-                    setJobConfig(false, 'config.process[0].train.unload_text_encoder')
+                    setJobConfig(false, 'config.process[0].train.unload_text_encoder');
                  }
                }}
              />
diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts
index b2af8b86..c08a649a 100644
--- a/ui/src/app/jobs/new/options.ts
+++ b/ui/src/app/jobs/new/options.ts
@@ -15,6 +15,7 @@ export interface ModelArch {
   defaults?: { [key: string]: any };
   disableSections?: DisableableSections[];
   additionalSections?: AdditionalSections[];
+  accuracyRecoveryAdapters?: { [key: string]: string };
 }

 const defaultNameOrPath = '';
@@ -230,9 +231,13 @@ export const modelArchs: ModelArch[] = [
      'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
      'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
      'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
+      'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
    },
    disableSections: ['network.conv'],
    additionalSections: ['model.low_vram'],
+    accuracyRecoveryAdapters: {
+      '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors',
+    },
  },
  {
    name: 'hidream',
diff --git a/version.py b/version.py
index 0e6b1dd4..eab0b98c 100644
--- a/version.py
+++ b/version.py
@@ -1 +1 @@
-VERSION = "0.4.0"
\ No newline at end of file
+VERSION = "0.4.1"
\ No newline at end of file
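
A minimal sketch of how the combined qtype string introduced in ModelConfig above is parsed. The adapter path matches the Qwen Image entry added to options.ts; the standalone variable names are illustrative only and are not part of the patch.

    # The UI passes the base quantization type and the ARA path as one string,
    # separated by "|"; ModelConfig.__init__ splits them apart.
    qtype = "uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors"
    accuracy_recovery_adapter = None
    if qtype is not None and "|" in qtype:
        qtype, accuracy_recovery_adapter = qtype.split("|")
    # qtype -> "uint3"
    # accuracy_recovery_adapter -> "ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors"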
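A sketch, under the same assumptions as quantize_model above, of how a non-local adapter path of the form 'username/repo_name/filename.safetensors' is resolved to a local file before the LoRA weights are loaded; the example path is the same Qwen Image adapter and everything else is illustrative.

    import os
    from huggingface_hub import hf_hub_download

    load_lora_path = "ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors"
    if not os.path.exists(load_lora_path):
        parts = load_lora_path.split("/")
        repo_id = f"{parts[0]}/{parts[1]}"  # "ostris/accuracy_recovery_adapters"
        load_lora_path = hf_hub_download(repo_id, filename=parts[-1])
    # load_lora_path now points at a local .safetensors file in the HF cache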