Add support for training with an accuracy recovery adapter with qwen image

This commit is contained in:
Jaret Burkett
2025-08-12 08:21:36 -06:00
parent 4ad18f3d00
commit 77b10d884d
8 changed files with 292 additions and 36 deletions

View File

@@ -2,7 +2,7 @@
import { useMemo } from 'react';
import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options';
import { defaultDatasetConfig } from './jobConfig';
import { JobConfig } from '@/types';
import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
import { objectCopy } from '@/utils/basic';
import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
import Card from '@/components/Card';
@@ -46,6 +46,47 @@ export default function SimpleJob({
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
}
const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => {
const hasARA = modelArch?.accuracyRecoveryAdapters && Object.keys(modelArch.accuracyRecoveryAdapters).length > 0;
if (!hasARA) {
return quantizationOptions;
}
let newQuantizationOptions = [
{
label: 'Standard',
options: [quantizationOptions[0], quantizationOptions[1]],
},
];
// add ARAs if they exist for the model
let ARAs: SelectOption[] = [];
if (modelArch.accuracyRecoveryAdapters) {
for (const [label, value] of Object.entries(modelArch.accuracyRecoveryAdapters)) {
ARAs.push({ value, label });
}
}
if (ARAs.length > 0) {
newQuantizationOptions.push({
label: 'Accuracy Recovery Adapters',
options: ARAs,
});
}
let additionalQuantizationOptions: SelectOption[] = [];
// add the quantization options if they are not already included
for (let i = 2; i < quantizationOptions.length; i++) {
const option = quantizationOptions[i];
additionalQuantizationOptions.push(option);
}
if (additionalQuantizationOptions.length > 0) {
newQuantizationOptions.push({
label: 'Additional Quantization Options',
options: additionalQuantizationOptions,
});
}
return newQuantizationOptions;
}, [modelArch]);
return (
<>
<form onSubmit={handleSubmit} className="space-y-8">
@@ -180,7 +221,7 @@ export default function SimpleJob({
}
setJobConfig(value, 'config.process[0].model.qtype');
}}
options={quantizationOptions}
options={transformerQuantizationOptions}
/>
<SelectInput
label="Text Encoder"
@@ -405,8 +446,8 @@ export default function SimpleJob({
label="Unload TE"
checked={jobConfig.config.process[0].train.unload_text_encoder || false}
docKey={'train.unload_text_encoder'}
onChange={(value) => {
setJobConfig(value, 'config.process[0].train.unload_text_encoder')
onChange={value => {
setJobConfig(value, 'config.process[0].train.unload_text_encoder');
if (value) {
setJobConfig(false, 'config.process[0].train.cache_text_embeddings');
}
@@ -416,10 +457,10 @@ export default function SimpleJob({
label="Cache Text Embeddings"
checked={jobConfig.config.process[0].train.cache_text_embeddings || false}
docKey={'train.cache_text_embeddings'}
onChange={(value) => {
setJobConfig(value, 'config.process[0].train.cache_text_embeddings')
onChange={value => {
setJobConfig(value, 'config.process[0].train.cache_text_embeddings');
if (value) {
setJobConfig(false, 'config.process[0].train.unload_text_encoder')
setJobConfig(false, 'config.process[0].train.unload_text_encoder');
}
}}
/>

View File

@@ -15,6 +15,7 @@ export interface ModelArch {
defaults?: { [key: string]: any };
disableSections?: DisableableSections[];
additionalSections?: AdditionalSections[];
accuracyRecoveryAdapters?: { [key: string]: string };
}
const defaultNameOrPath = '';
@@ -230,9 +231,13 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
},
disableSections: ['network.conv'],
additionalSections: ['model.low_vram'],
accuracyRecoveryAdapters: {
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors',
},
},
{
name: 'hidream',