Add ability to set the quantization type for the text encoders and transformer in the UI

Jaret Burkett
2025-07-27 18:00:53 -06:00
parent b717586ee2
commit ed8d14225f
5 changed files with 68 additions and 32 deletions
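
For orientation, a minimal sketch (not part of the diff) of what a model config looks like once the new fields are set. The field names match the ModelConfig changes below; the object literal and the uint4 text-encoder value are illustrative only.

// Sketch, not from this commit: a model config with the new qtype fields populated.
// qfloat8 is the default added to jobConfig below; uint4 is one of the new
// quantizationOptions values, chosen here purely as an example.
const exampleModelConfig = {
  name_or_path: 'ostris/Flex.1-alpha',
  quantize: true,     // quantize the transformer
  qtype: 'qfloat8',   // transformer quantization type (new)
  quantize_te: true,  // quantize the text encoder(s)
  qtype_te: 'uint4',  // text encoder quantization type (new)
  arch: 'flex1',
  low_vram: false,
};

In the UI, picking "- NONE -" in either dropdown sets the corresponding quantize flag to false and resets the stored qtype to defaultQtype ('qfloat8').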

View File

@@ -1,6 +1,6 @@
'use client';
import { useMemo } from 'react';
-import { modelArchs, ModelArch, groupedModelOptions } from './options';
+import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options';
import { defaultDatasetConfig } from './jobConfig';
import { JobConfig } from '@/types';
import { objectCopy } from '@/utils/basic';
@@ -40,10 +40,16 @@ export default function SimpleJob({
const isVideoModel = !!(modelArch?.group === 'video');
+let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6';
+if (modelArch?.disableSections?.includes('model.quantize')) {
+topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
+}
return (
<>
<form onSubmit={handleSubmit} className="space-y-8">
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
<div className={topBarClass}>
<Card title="Job Settings">
<TextInput
label="Training Name"
@@ -89,8 +95,8 @@ export default function SimpleJob({
// update the defaults when a model is selected
const newArch = modelArchs.find(model => model.name === value);
-// update vram setting
-if (!(newArch?.additionalSections?.includes('model.low_vram'))) {
+// update vram setting
+if (!newArch?.additionalSections?.includes('model.low_vram')) {
setJobConfig(false, 'config.process[0].model.low_vram');
}
@@ -150,32 +156,48 @@ export default function SimpleJob({
placeholder=""
required
/>
-{modelArch?.disableSections?.includes('model.quantize') ? null : (
-<FormGroup label="Quantize">
-<div className="grid grid-cols-2 gap-2">
-<Checkbox
-label="Transformer"
-checked={jobConfig.config.process[0].model.quantize}
-onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
-/>
-<Checkbox
-label="Text Encoder"
-checked={jobConfig.config.process[0].model.quantize_te}
-onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
-/>
-</div>
-</FormGroup>
-)}
{modelArch?.additionalSections?.includes('model.low_vram') && (
<FormGroup label="Options">
-<Checkbox
-label="Low VRAM"
-checked={jobConfig.config.process[0].model.low_vram}
-onChange={value => setJobConfig(value, 'config.process[0].model.low_vram')}
-/>
+<Checkbox
+label="Low VRAM"
+checked={jobConfig.config.process[0].model.low_vram}
+onChange={value => setJobConfig(value, 'config.process[0].model.low_vram')}
+/>
</FormGroup>
)}
</Card>
+{modelArch?.disableSections?.includes('model.quantize') ? null : (
+<Card title="Quantization">
+<SelectInput
+label="Transformer"
+value={jobConfig.config.process[0].model.quantize ? jobConfig.config.process[0].model.qtype : ''}
+onChange={value => {
+if (value === '') {
+setJobConfig(false, 'config.process[0].model.quantize');
+value = defaultQtype;
+} else {
+setJobConfig(true, 'config.process[0].model.quantize');
+}
+setJobConfig(value, 'config.process[0].model.qtype');
+}}
+options={quantizationOptions}
+/>
+<SelectInput
+label="Text Encoder"
+value={jobConfig.config.process[0].model.quantize_te ? jobConfig.config.process[0].model.qtype_te : ''}
+onChange={value => {
+if (value === '') {
+setJobConfig(false, 'config.process[0].model.quantize_te');
+value = defaultQtype;
+} else {
+setJobConfig(true, 'config.process[0].model.quantize_te');
+}
+setJobConfig(value, 'config.process[0].model.qtype_te');
+}}
+options={quantizationOptions}
+/>
+</Card>
+)}
<Card title="Target Configuration">
<SelectInput
label="Target Type"

View File

@@ -80,7 +80,9 @@ export const defaultJobConfig: JobConfig = {
model: {
name_or_path: 'ostris/Flex.1-alpha',
quantize: true,
+qtype: 'qfloat8',
quantize_te: true,
+qtype_te: 'qfloat8',
arch: 'flex1',
low_vram: false,
model_kwargs: {},

View File

@@ -1,4 +1,4 @@
-import { GroupedSelectOption } from "@/types";
+import { GroupedSelectOption, SelectOption } from '@/types';
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
@@ -19,9 +19,6 @@ export interface ModelArch {
const defaultNameOrPath = '';
export const modelArchs: ModelArch[] = [
{
name: 'flux',
@@ -262,10 +259,9 @@ export const modelArchs: ModelArch[] = [
},
].sort((a, b) => {
// Sort by label, case-insensitive
-return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' })
+return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' });
}) as any;
export const groupedModelOptions: GroupedSelectOption[] = modelArchs.reduce((acc, arch) => {
const group = acc.find(g => g.label === arch.group);
if (group) {
@@ -278,3 +274,17 @@ export const groupedModelOptions: GroupedSelectOption[] = modelArchs.reduce((acc
}
return acc;
}, [] as GroupedSelectOption[]);
+export const quantizationOptions: SelectOption[] = [
+{ value: '', label: '- NONE -' },
+{ value: 'qfloat8', label: 'float8 (default)' },
+{ value: 'uint8', label: '8 bit' },
+{ value: 'uint7', label: '7 bit' },
+{ value: 'uint6', label: '6 bit' },
+{ value: 'uint5', label: '5 bit' },
+{ value: 'uint4', label: '4 bit' },
+{ value: 'uint3', label: '3 bit' },
+{ value: 'uint2', label: '2 bit' },
+];
+export const defaultQtype = 'qfloat8';

View File

@@ -127,6 +127,8 @@ export interface ModelConfig {
name_or_path: string;
quantize: boolean;
quantize_te: boolean;
+qtype: string;
+qtype_te: string;
quantize_kwargs?: QuantizeKwargsConfig;
arch: string;
low_vram: boolean;

View File

@@ -1 +1 @@
VERSION = "0.3.11"
VERSION = "0.3.12"