mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add ability to set the quantization type for text encoders and transformer in the ui
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
'use client';
|
||||
import { useMemo } from 'react';
|
||||
import { modelArchs, ModelArch, groupedModelOptions } from './options';
|
||||
import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options';
|
||||
import { defaultDatasetConfig } from './jobConfig';
|
||||
import { JobConfig } from '@/types';
|
||||
import { objectCopy } from '@/utils/basic';
|
||||
@@ -40,10 +40,16 @@ export default function SimpleJob({
|
||||
|
||||
const isVideoModel = !!(modelArch?.group === 'video');
|
||||
|
||||
let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6';
|
||||
|
||||
if (modelArch?.disableSections?.includes('model.quantize')) {
|
||||
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<form onSubmit={handleSubmit} className="space-y-8">
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||
<div className={topBarClass}>
|
||||
<Card title="Job Settings">
|
||||
<TextInput
|
||||
label="Training Name"
|
||||
@@ -89,8 +95,8 @@ export default function SimpleJob({
|
||||
// update the defaults when a model is selected
|
||||
const newArch = modelArchs.find(model => model.name === value);
|
||||
|
||||
// update vram setting
|
||||
if (!(newArch?.additionalSections?.includes('model.low_vram'))) {
|
||||
// update vram setting
|
||||
if (!newArch?.additionalSections?.includes('model.low_vram')) {
|
||||
setJobConfig(false, 'config.process[0].model.low_vram');
|
||||
}
|
||||
|
||||
@@ -150,32 +156,48 @@ export default function SimpleJob({
|
||||
placeholder=""
|
||||
required
|
||||
/>
|
||||
{modelArch?.disableSections?.includes('model.quantize') ? null : (
|
||||
<FormGroup label="Quantize">
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
<Checkbox
|
||||
label="Transformer"
|
||||
checked={jobConfig.config.process[0].model.quantize}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
|
||||
/>
|
||||
<Checkbox
|
||||
label="Text Encoder"
|
||||
checked={jobConfig.config.process[0].model.quantize_te}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
|
||||
/>
|
||||
</div>
|
||||
</FormGroup>
|
||||
)}
|
||||
{modelArch?.additionalSections?.includes('model.low_vram') && (
|
||||
<FormGroup label="Options">
|
||||
<Checkbox
|
||||
label="Low VRAM"
|
||||
checked={jobConfig.config.process[0].model.low_vram}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.low_vram')}
|
||||
/>
|
||||
<Checkbox
|
||||
label="Low VRAM"
|
||||
checked={jobConfig.config.process[0].model.low_vram}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.low_vram')}
|
||||
/>
|
||||
</FormGroup>
|
||||
)}
|
||||
</Card>
|
||||
{modelArch?.disableSections?.includes('model.quantize') ? null : (
|
||||
<Card title="Quantization">
|
||||
<SelectInput
|
||||
label="Transformer"
|
||||
value={jobConfig.config.process[0].model.quantize ? jobConfig.config.process[0].model.qtype : ''}
|
||||
onChange={value => {
|
||||
if (value === '') {
|
||||
setJobConfig(false, 'config.process[0].model.quantize');
|
||||
value = defaultQtype;
|
||||
} else {
|
||||
setJobConfig(true, 'config.process[0].model.quantize');
|
||||
}
|
||||
setJobConfig(value, 'config.process[0].model.qtype');
|
||||
}}
|
||||
options={quantizationOptions}
|
||||
/>
|
||||
<SelectInput
|
||||
label="Text Encoder"
|
||||
value={jobConfig.config.process[0].model.quantize_te ? jobConfig.config.process[0].model.qtype_te : ''}
|
||||
onChange={value => {
|
||||
if (value === '') {
|
||||
setJobConfig(false, 'config.process[0].model.quantize_te');
|
||||
value = defaultQtype;
|
||||
} else {
|
||||
setJobConfig(true, 'config.process[0].model.quantize_te');
|
||||
}
|
||||
setJobConfig(value, 'config.process[0].model.qtype_te');
|
||||
}}
|
||||
options={quantizationOptions}
|
||||
/>
|
||||
</Card>
|
||||
)}
|
||||
<Card title="Target Configuration">
|
||||
<SelectInput
|
||||
label="Target Type"
|
||||
|
||||
@@ -80,7 +80,9 @@ export const defaultJobConfig: JobConfig = {
|
||||
model: {
|
||||
name_or_path: 'ostris/Flex.1-alpha',
|
||||
quantize: true,
|
||||
qtype: 'qfloat8',
|
||||
quantize_te: true,
|
||||
qtype_te: 'qfloat8',
|
||||
arch: 'flex1',
|
||||
low_vram: false,
|
||||
model_kwargs: {},
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { GroupedSelectOption } from "@/types";
|
||||
import { GroupedSelectOption, SelectOption } from '@/types';
|
||||
|
||||
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
||||
|
||||
@@ -19,9 +19,6 @@ export interface ModelArch {
|
||||
|
||||
const defaultNameOrPath = '';
|
||||
|
||||
|
||||
|
||||
|
||||
export const modelArchs: ModelArch[] = [
|
||||
{
|
||||
name: 'flux',
|
||||
@@ -262,10 +259,9 @@ export const modelArchs: ModelArch[] = [
|
||||
},
|
||||
].sort((a, b) => {
|
||||
// Sort by label, case-insensitive
|
||||
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' })
|
||||
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' });
|
||||
}) as any;
|
||||
|
||||
|
||||
export const groupedModelOptions: GroupedSelectOption[] = modelArchs.reduce((acc, arch) => {
|
||||
const group = acc.find(g => g.label === arch.group);
|
||||
if (group) {
|
||||
@@ -278,3 +274,17 @@ export const groupedModelOptions: GroupedSelectOption[] = modelArchs.reduce((acc
|
||||
}
|
||||
return acc;
|
||||
}, [] as GroupedSelectOption[]);
|
||||
|
||||
export const quantizationOptions: SelectOption[] = [
|
||||
{ value: '', label: '- NONE -' },
|
||||
{ value: 'qfloat8', label: 'float8 (default)' },
|
||||
{ value: 'uint8', label: '8 bit' },
|
||||
{ value: 'uint7', label: '7 bit' },
|
||||
{ value: 'uint6', label: '6 bit' },
|
||||
{ value: 'uint5', label: '5 bit' },
|
||||
{ value: 'uint4', label: '4 bit' },
|
||||
{ value: 'uint3', label: '3 bit' },
|
||||
{ value: 'uint2', label: '2 bit' },
|
||||
];
|
||||
|
||||
export const defaultQtype = 'qfloat8';
|
||||
|
||||
@@ -127,6 +127,8 @@ export interface ModelConfig {
|
||||
name_or_path: string;
|
||||
quantize: boolean;
|
||||
quantize_te: boolean;
|
||||
qtype: string;
|
||||
qtype_te: string;
|
||||
quantize_kwargs?: QuantizeKwargsConfig;
|
||||
arch: string;
|
||||
low_vram: boolean;
|
||||
|
||||
Reference in New Issue
Block a user