mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-03-14 06:57:35 +00:00)

Add support for training with an accuracy recovery adapter with qwen image

@@ -10,10 +10,9 @@ from toolkit.models.base_model import BaseModel
 from toolkit.basic import flush
 from toolkit.prompt_utils import PromptEmbeds
 from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
-from toolkit.dequantize import patch_dequantization_on_save
 from toolkit.accelerator import get_accelerator, unwrap_model
 from optimum.quanto import freeze, QTensor
-from toolkit.util.quantize import quantize, get_qtype
+from toolkit.util.quantize import quantize, get_qtype, quantize_model
 import torch.nn.functional as F

 from diffusers import QwenImagePipeline, QwenImageTransformer2DModel, AutoencoderKLQwenImage

@@ -99,23 +98,9 @@ class QwenImageModel(BaseModel):
         )

         if self.model_config.quantize:
-            # patch the state dict method
-            patch_dequantization_on_save(transformer)
-            # move and quantize only certain pieces at a time.
-            quantization_type = get_qtype(self.model_config.qtype)
-            all_blocks = list(transformer.transformer_blocks)
-            self.print_and_status_update(" - quantizing transformer blocks")
-            for block in tqdm(all_blocks):
-                block.to(self.device_torch, dtype=dtype)
-                quantize(block, weights=quantization_type)
-                freeze(block)
-                block.to('cpu')
-                # flush()
-
-            self.print_and_status_update(" - quantizing extras")
-            transformer.to(self.device_torch, dtype=dtype)
-            quantize(transformer, weights=quantization_type)
-            freeze(transformer)
+            self.print_and_status_update("Quantizing Transformer")
+            quantize_model(self, transformer)
+            flush()

         if self.model_config.low_vram:
             self.print_and_status_update("Moving transformer to CPU")

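The quantize_model helper the hunk above now delegates to (added later in this diff) accepts the accuracy recovery adapter either as a local file or as a 'username/repo_name/filename.safetensors' Hugging Face reference. A minimal sketch of how that reference decomposes, reusing the adapter path that appears in this commit's UI options:

    load_lora_path = "ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors"
    path_split = load_lora_path.split("/")
    repo_id = f"{path_split[0]}/{path_split[1]}"  # "ostris/accuracy_recovery_adapters"
    filename = path_split[-1]                     # "qwen_image_torchao_uint3.safetensors"
    # quantize_model then fetches it with hf_hub_download(repo_id, filename=filename)
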
@@ -1533,6 +1533,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # run base sd process run
         self.sd.load_model()

+        # compile the model if needed
+        if self.model_config.compile:
+            try:
+                torch.compile(self.sd.unet, dynamic=True, fullgraph=True, mode='max-autotune')
+            except Exception as e:
+                print_acc(f"Failed to compile model: {e}")
+                print_acc("Continuing without compilation")
+
         self.sd.add_after_sample_image_hook(self.sample_step_hook)

         dtype = get_torch_dtype(self.train_config.dtype)

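One point worth noting about the block above: torch.compile returns the compiled wrapper (an OptimizedModule) and leaves the original module unchanged, so call sites usually keep a reference to the result. A minimal sketch of that pattern with the same flags; the toy Linear module is made up for illustration:

    import torch

    model = torch.nn.Linear(8, 8)
    # torch.compile returns the compiled callable; the input module is not mutated
    compiled = torch.compile(model, dynamic=True, fullgraph=True, mode='max-autotune')
    out = compiled(torch.randn(1, 8))  # first call triggers compilation
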
@@ -601,6 +601,16 @@ class ModelConfig:
         # 20 different model variants
         self.extras_name_or_path = kwargs.get("extras_name_or_path", self.name_or_path)

+        # path to an accuracy recovery adapter, either local or remote
+        self.accuracy_recovery_adapter = kwargs.get("accuracy_recovery_adapter", None)
+
+        # parse ARA from qtype
+        if self.qtype is not None and "|" in self.qtype:
+            self.qtype, self.accuracy_recovery_adapter = self.qtype.split('|')
+
+        # compile the model with torch compile
+        self.compile = kwargs.get("compile", False)
+
         # kwargs to pass to the model
         self.model_kwargs = kwargs.get("model_kwargs", {})

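The pipe convention parsed above can be checked in isolation. A minimal sketch, reusing the adapter path that appears later in this diff:

    qtype = "uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors"
    accuracy_recovery_adapter = None
    # everything before the pipe is the quantization type, everything after is the ARA path
    if "|" in qtype:
        qtype, accuracy_recovery_adapter = qtype.split("|")
    assert qtype == "uint3"
    assert accuracy_recovery_adapter.endswith("qwen_image_torchao_uint3.safetensors")
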
@@ -168,8 +168,10 @@ class BaseModel:
         self._after_sample_img_hooks = []
         self._status_update_hooks = []
         self.is_transformer = False

         self.sample_prompts_cache = None

+        self.accuracy_recovery_adapter: Union[None, 'LoRASpecialNetwork'] = None
+
     # properties for old arch for backwards compatibility
     @property

@@ -1,19 +1,36 @@
 from fnmatch import fnmatch
-from typing import Any, Dict, List, Optional, Union
+from typing import List, Optional, Union, TYPE_CHECKING
 import torch
 from dataclasses import dataclass

 from optimum.quanto.quantize import _quantize_submodule
 from optimum.quanto.tensor import Optimizer, qtype, qtypes
 from torchao.quantization.quant_api import (
     quantize_ as torchao_quantize_,
     Float8WeightOnlyConfig,
-    UIntXWeightOnlyConfig
+    UIntXWeightOnlyConfig,
 )
 from optimum.quanto import freeze
 from tqdm import tqdm
+from safetensors.torch import load_file
+from huggingface_hub import hf_hub_download

 from toolkit.print import print_acc
+import os
+
+if TYPE_CHECKING:
+    from toolkit.models.base_model import BaseModel

 # the quantize function in quanto had a bug where it was using exclude instead of include

-Q_MODULES = ['QLinear', 'QConv2d', 'QEmbedding', 'QBatchNorm2d', 'QLayerNorm', 'QConvTranspose2d', 'QEmbeddingBag']
+Q_MODULES = [
+    "QLinear",
+    "QConv2d",
+    "QEmbedding",
+    "QBatchNorm2d",
+    "QLayerNorm",
+    "QConvTranspose2d",
+    "QEmbeddingBag",
+]

 torchao_qtypes = {
     # "int4": Int4WeightOnlyConfig(),

@@ -27,11 +44,13 @@ torchao_qtypes = {
     "float8": Float8WeightOnlyConfig(),
 }

+
 class aotype:
     def __init__(self, name: str):
         self.name = name
         self.config = torchao_qtypes[name]

+
 def get_qtype(qtype: Union[str, qtype]) -> qtype:
     if qtype in torchao_qtypes:
         return aotype(qtype)

@@ -40,6 +59,7 @@ def get_qtype(qtype: Union[str, qtype]) -> qtype:
     else:
         return qtype

+
 def quantize(
     model: torch.nn.Module,
     weights: Optional[Union[str, qtype, aotype]] = None,

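The effect of get_qtype is that torchao-backed names come back wrapped in aotype, which is what lets quantize() route them to torchao_quantize_ rather than quanto's _quantize_submodule. A small sketch against the names visible in this diff ("float8" is the one torchao entry shown):

    from toolkit.util.quantize import aotype, get_qtype

    qt = get_qtype("float8")
    assert isinstance(qt, aotype)  # qt.config holds Float8WeightOnlyConfig()
    assert qt.name == "float8"
    # non-torchao names such as "qfloat8" fall through to quanto's qtype handling
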
@@ -79,7 +99,9 @@ def quantize(
     if exclude is not None:
         exclude = [exclude] if isinstance(exclude, str) else exclude
     for name, m in model.named_modules():
-        if include is not None and not any(fnmatch(name, pattern) for pattern in include):
+        if include is not None and not any(
+            fnmatch(name, pattern) for pattern in include
+        ):
             continue
         if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
             continue

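The include/exclude filters above match dotted module names against shell-style patterns. A minimal sketch; the module name is made up:

    from fnmatch import fnmatch

    name = "transformer_blocks.0.attn.to_q"       # hypothetical module name
    assert fnmatch(name, "transformer_blocks.*")  # an include pattern like this keeps it
    assert not fnmatch(name, "*norm*")            # an exclude pattern like this skips it
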
@@ -91,8 +113,191 @@ def quantize(
             if isinstance(weights, aotype):
                 torchao_quantize_(m, weights.config)
             else:
-                _quantize_submodule(model, name, m, weights=weights,
-                                    activations=activations, optimizer=optimizer)
+                _quantize_submodule(
+                    model,
+                    name,
+                    m,
+                    weights=weights,
+                    activations=activations,
+                    optimizer=optimizer,
+                )
         except Exception as e:
             print(f"Failed to quantize {name}: {e}")
             # raise e
+
+
+def quantize_model(
+    base_model: "BaseModel",
+    model_to_quantize: torch.nn.Module,
+):
+    from toolkit.dequantize import patch_dequantization_on_save
+
+    if not hasattr(base_model, "get_transformer_block_names"):
+        raise ValueError(
+            "The model to quantize must have a method `get_transformer_block_names`."
+        )
+
+    # patch the state dict method
+    patch_dequantization_on_save(model_to_quantize)
+
+    if base_model.model_config.accuracy_recovery_adapter is not None:
+        from toolkit.config_modules import NetworkConfig
+        from toolkit.lora_special import LoRASpecialNetwork
+
+        # we need to load and quantize with an accuracy recovery adapter
+        # todo handle hf repos
+        load_lora_path = base_model.model_config.accuracy_recovery_adapter
+
+        if not os.path.exists(load_lora_path):
+            # not a local file, grab it from the hub
+            path_split = load_lora_path.split("/")
+            if len(path_split) > 3:
+                raise ValueError(
+                    "The accuracy recovery adapter path must be a local path or, for a hf repo, 'username/repo_name/filename.safetensors'."
+                )
+            repo_id = f"{path_split[0]}/{path_split[1]}"
+            print_acc(f"Grabbing lora from the hub: {load_lora_path}")
+            new_lora_path = hf_hub_download(
+                repo_id,
+                filename=path_split[-1],
+            )
+            # replace the path
+            load_lora_path = new_lora_path
+
+        # build the lora config based on the lora weights
+        lora_state_dict = load_file(load_lora_path)
+
+        if hasattr(base_model, "convert_lora_weights_before_load"):
+            lora_state_dict = base_model.convert_lora_weights_before_load(lora_state_dict)
+
+        network_config = {
+            "type": "lora",
+            "network_kwargs": {"only_if_contains": []},
+            "transformer_only": False,
+        }
+        first_key = list(lora_state_dict.keys())[0]
+        first_weight = lora_state_dict[first_key]
+        # if it starts with lycoris and includes lokr
+        if first_key.startswith("lycoris") and any(
+            "lokr" in key for key in lora_state_dict.keys()
+        ):
+            network_config["type"] = "lokr"
+
+        network_kwargs = {}
+
+        # find the first lora_A weight to get the dim
+        if network_config["type"] == "lora":
+            linear_dim = None
+            for key, value in lora_state_dict.items():
+                if "lora_A" in key:
+                    linear_dim = int(value.shape[0])
+                    break
+            linear_alpha = linear_dim
+            network_config["linear"] = linear_dim
+            network_config["linear_alpha"] = linear_alpha
+
+            # we build the keys to match every key
+            only_if_contains = []
+            for key in lora_state_dict.keys():
+                contains_key = key.split(".lora_")[0]
+                if contains_key not in only_if_contains:
+                    only_if_contains.append(contains_key)
+
+            network_kwargs["only_if_contains"] = only_if_contains
+        elif network_config["type"] == "lokr":
+            # find the factor
+            largest_factor = 0
+            for key, value in lora_state_dict.items():
+                if "lokr_w1" in key:
+                    factor = int(value.shape[0])
+                    if factor > largest_factor:
+                        largest_factor = factor
+            network_config["lokr_full_rank"] = True
+            network_config["lokr_factor"] = largest_factor
+
+            only_if_contains = []
+            for key in lora_state_dict.keys():
+                if "lokr_w1" in key:
+                    contains_key = key.split(".lokr_w1")[0]
+                    contains_key = contains_key.replace("lycoris_", "")
+                    if contains_key not in only_if_contains:
+                        only_if_contains.append(contains_key)
+            network_kwargs["only_if_contains"] = only_if_contains
+
+        if hasattr(base_model, 'target_lora_modules'):
+            network_kwargs['target_lin_modules'] = base_model.target_lora_modules
+
+        # todo auto grab these
+        # get dim and scale
+        network_config = NetworkConfig(**network_config)
+
+        network = LoRASpecialNetwork(
+            text_encoder=None,
+            unet=model_to_quantize,
+            lora_dim=network_config.linear,
+            multiplier=1.0,
+            alpha=network_config.linear_alpha,
+            # conv_lora_dim=self.network_config.conv,
+            # conv_alpha=self.network_config.conv_alpha,
+            train_unet=True,
+            train_text_encoder=False,
+            network_config=network_config,
+            network_type=network_config.type,
+            transformer_only=network_config.transformer_only,
+            is_transformer=base_model.is_transformer,
+            base_model=base_model,
+            **network_kwargs
+        )
+        network.apply_to(
+            None, model_to_quantize, apply_text_encoder=False, apply_unet=True
+        )
+        network.force_to(base_model.device_torch, dtype=base_model.torch_dtype)
+        network._update_torch_multiplier()
+        network.load_weights(lora_state_dict)
+        network.eval()
+        network.is_active = True
+        network.can_merge_in = False
+        base_model.accuracy_recovery_adapter = network
+
+        # quantize it
+        quantization_type = get_qtype(base_model.model_config.qtype)
+        for lora_module in tqdm(network.unet_loras, desc="Attaching quantization"):
+            # the lora has already hijacked the original module
+            orig_module = lora_module.org_module[0]
+            orig_module.to(base_model.torch_dtype)
+            # make the params not require gradients
+            for param in orig_module.parameters():
+                param.requires_grad = False
+            quantize(orig_module, weights=quantization_type)
+            freeze(orig_module)
+            if base_model.model_config.low_vram:
+                # move it back to cpu
+                orig_module.to("cpu")
+
+    else:
+        # quantize the model the original way, without an accuracy recovery adapter
+        # move and quantize only certain pieces at a time.
+        quantization_type = get_qtype(base_model.model_config.qtype)
+        # all_blocks = list(model_to_quantize.transformer_blocks)
+        all_blocks: List[torch.nn.Module] = []
+        transformer_block_names = base_model.get_transformer_block_names()
+        for name in transformer_block_names:
+            block = getattr(model_to_quantize, name, None)
+            if block is not None:
+                all_blocks.append(block)
+        base_model.print_and_status_update(
+            f" - quantizing {len(all_blocks)} transformer blocks"
+        )
+        for block in tqdm(all_blocks):
+            block.to(base_model.device_torch, dtype=base_model.torch_dtype)
+            quantize(block, weights=quantization_type)
+            freeze(block)
+            block.to("cpu")
+
+        # todo: on extras, find a universal way to quantize them on device and move them back to their original
+        # device without having to move the transformer blocks to the device first
+        base_model.print_and_status_update(" - quantizing extras")
+        model_to_quantize.to(base_model.device_torch, dtype=base_model.torch_dtype)
+        quantize(model_to_quantize, weights=quantization_type)
+        freeze(model_to_quantize)

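The rank detection in quantize_model relies on the standard LoRA layout, where a lora_A weight has shape [rank, in_features], so shape[0] is the adapter rank (alpha then defaults to the same value). A minimal sketch with a made-up state dict key:

    import torch

    # hypothetical entry; real keys come from the downloaded ARA safetensors file
    lora_state_dict = {
        "transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(16, 3072),
    }
    linear_dim = None
    for key, value in lora_state_dict.items():
        if "lora_A" in key:
            linear_dim = int(value.shape[0])  # adapter rank
            break
    assert linear_dim == 16
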
@@ -2,7 +2,7 @@
 import { useMemo } from 'react';
 import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options';
 import { defaultDatasetConfig } from './jobConfig';
-import { JobConfig } from '@/types';
+import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
 import { objectCopy } from '@/utils/basic';
 import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
 import Card from '@/components/Card';

@@ -46,6 +46,47 @@ export default function SimpleJob({
     topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
   }

+  const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => {
+    const hasARA = modelArch?.accuracyRecoveryAdapters && Object.keys(modelArch.accuracyRecoveryAdapters).length > 0;
+    if (!hasARA) {
+      return quantizationOptions;
+    }
+    let newQuantizationOptions = [
+      {
+        label: 'Standard',
+        options: [quantizationOptions[0], quantizationOptions[1]],
+      },
+    ];
+
+    // add ARAs if they exist for the model
+    let ARAs: SelectOption[] = [];
+    if (modelArch.accuracyRecoveryAdapters) {
+      for (const [label, value] of Object.entries(modelArch.accuracyRecoveryAdapters)) {
+        ARAs.push({ value, label });
+      }
+    }
+    if (ARAs.length > 0) {
+      newQuantizationOptions.push({
+        label: 'Accuracy Recovery Adapters',
+        options: ARAs,
+      });
+    }
+
+    let additionalQuantizationOptions: SelectOption[] = [];
+    // add the remaining quantization options
+    for (let i = 2; i < quantizationOptions.length; i++) {
+      const option = quantizationOptions[i];
+      additionalQuantizationOptions.push(option);
+    }
+    if (additionalQuantizationOptions.length > 0) {
+      newQuantizationOptions.push({
+        label: 'Additional Quantization Options',
+        options: additionalQuantizationOptions,
+      });
+    }
+    return newQuantizationOptions;
+  }, [modelArch]);
+
   return (
     <>
       <form onSubmit={handleSubmit} className="space-y-8">

@@ -180,7 +221,7 @@ export default function SimpleJob({
               }
               setJobConfig(value, 'config.process[0].model.qtype');
             }}
-            options={quantizationOptions}
+            options={transformerQuantizationOptions}
           />
           <SelectInput
             label="Text Encoder"

@@ -405,8 +446,8 @@
               label="Unload TE"
               checked={jobConfig.config.process[0].train.unload_text_encoder || false}
               docKey={'train.unload_text_encoder'}
-              onChange={(value) => {
-                setJobConfig(value, 'config.process[0].train.unload_text_encoder')
+              onChange={value => {
+                setJobConfig(value, 'config.process[0].train.unload_text_encoder');
                 if (value) {
                   setJobConfig(false, 'config.process[0].train.cache_text_embeddings');
                 }

@@ -416,10 +457,10 @@
               label="Cache Text Embeddings"
               checked={jobConfig.config.process[0].train.cache_text_embeddings || false}
               docKey={'train.cache_text_embeddings'}
-              onChange={(value) => {
-                setJobConfig(value, 'config.process[0].train.cache_text_embeddings')
+              onChange={value => {
+                setJobConfig(value, 'config.process[0].train.cache_text_embeddings');
                 if (value) {
-                  setJobConfig(false, 'config.process[0].train.unload_text_encoder')
+                  setJobConfig(false, 'config.process[0].train.unload_text_encoder');
                 }
               }}
             />

@@ -15,6 +15,7 @@ export interface ModelArch {
   defaults?: { [key: string]: any };
   disableSections?: DisableableSections[];
   additionalSections?: AdditionalSections[];
+  accuracyRecoveryAdapters?: { [key: string]: string };
 }

 const defaultNameOrPath = '';

@@ -230,9 +231,13 @@ export const modelArchs: ModelArch[] = [
       'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
       'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
       'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
+      'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
     },
     disableSections: ['network.conv'],
     additionalSections: ['model.low_vram'],
+    accuracyRecoveryAdapters: {
+      '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors',
+    },
   },
   {
     name: 'hidream',

@@ -1 +1 @@
-VERSION = "0.4.0"
+VERSION = "0.4.1"