mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Add support for training with an accuracy recovery adapter with qwen image
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
import { useMemo } from 'react';
|
||||
import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options';
|
||||
import { defaultDatasetConfig } from './jobConfig';
|
||||
import { JobConfig } from '@/types';
|
||||
import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
|
||||
import { objectCopy } from '@/utils/basic';
|
||||
import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
|
||||
import Card from '@/components/Card';
|
||||
@@ -46,6 +46,47 @@ export default function SimpleJob({
|
||||
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
|
||||
}
|
||||
|
||||
/**
 * Quantization choices for the currently selected model architecture.
 *
 * - When the architecture declares no accuracy recovery adapters (ARAs), the
 *   shared flat `quantizationOptions` list is returned unchanged.
 * - Otherwise the choices are returned as labelled groups: the first two
 *   shared options ("Standard"), the architecture's ARAs, and then any
 *   remaining shared options ("Additional Quantization Options").
 *
 * Recomputed only when `modelArch` changes.
 */
const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => {
  // Each adapter entry maps a display label (key) to the option value (value).
  const adapters: { [key: string]: string } = modelArch?.accuracyRecoveryAdapters ?? {};
  const araOptions: SelectOption[] = Object.entries(adapters).map(([label, value]) => ({ value, label }));

  // No adapters for this architecture: keep the plain, ungrouped list.
  if (araOptions.length === 0) {
    return quantizationOptions;
  }

  // slice() instead of indexing [0] and [1], so a list with fewer than two
  // entries never produces `undefined` options in the "Standard" group.
  const groupedOptions = [
    {
      label: 'Standard',
      options: quantizationOptions.slice(0, 2),
    },
    {
      label: 'Accuracy Recovery Adapters',
      options: araOptions,
    },
  ];

  // Everything after the first two shared options goes into its own group;
  // the group is omitted entirely when there is nothing left to show.
  const additionalQuantizationOptions = quantizationOptions.slice(2);
  if (additionalQuantizationOptions.length > 0) {
    groupedOptions.push({
      label: 'Additional Quantization Options',
      options: additionalQuantizationOptions,
    });
  }

  return groupedOptions;
}, [modelArch]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<form onSubmit={handleSubmit} className="space-y-8">
|
||||
@@ -180,7 +221,7 @@ export default function SimpleJob({
|
||||
}
|
||||
setJobConfig(value, 'config.process[0].model.qtype');
|
||||
}}
|
||||
options={quantizationOptions}
|
||||
options={transformerQuantizationOptions}
|
||||
/>
|
||||
<SelectInput
|
||||
label="Text Encoder"
|
||||
@@ -405,8 +446,8 @@ export default function SimpleJob({
|
||||
label="Unload TE"
|
||||
checked={jobConfig.config.process[0].train.unload_text_encoder || false}
|
||||
docKey={'train.unload_text_encoder'}
|
||||
onChange={(value) => {
|
||||
setJobConfig(value, 'config.process[0].train.unload_text_encoder')
|
||||
onChange={value => {
|
||||
setJobConfig(value, 'config.process[0].train.unload_text_encoder');
|
||||
if (value) {
|
||||
setJobConfig(false, 'config.process[0].train.cache_text_embeddings');
|
||||
}
|
||||
@@ -416,10 +457,10 @@ export default function SimpleJob({
|
||||
label="Cache Text Embeddings"
|
||||
checked={jobConfig.config.process[0].train.cache_text_embeddings || false}
|
||||
docKey={'train.cache_text_embeddings'}
|
||||
onChange={(value) => {
|
||||
setJobConfig(value, 'config.process[0].train.cache_text_embeddings')
|
||||
onChange={value => {
|
||||
setJobConfig(value, 'config.process[0].train.cache_text_embeddings');
|
||||
if (value) {
|
||||
setJobConfig(false, 'config.process[0].train.unload_text_encoder')
|
||||
setJobConfig(false, 'config.process[0].train.unload_text_encoder');
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
@@ -15,6 +15,7 @@ export interface ModelArch {
|
||||
defaults?: { [key: string]: any };
|
||||
disableSections?: DisableableSections[];
|
||||
additionalSections?: AdditionalSections[];
|
||||
accuracyRecoveryAdapters?: { [key: string]: string };
|
||||
}
|
||||
|
||||
const defaultNameOrPath = '';
|
||||
@@ -230,9 +231,13 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
||||
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['model.low_vram'],
|
||||
accuracyRecoveryAdapters: {
|
||||
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors',
|
||||
},
|
||||
},
|
||||
{
|
||||
name: 'hidream',
|
||||
|
||||
Reference in New Issue
Block a user