Added support for sdxl and sd1.5 to the ui.

This commit is contained in:
Jaret Burkett
2025-06-10 10:03:54 -06:00
parent d5c547da43
commit f8fb3b9c45
6 changed files with 131 additions and 53 deletions

View File

@@ -414,3 +414,12 @@ To learn more about LoKr, read more about it at [KohakuBlueleaf/LyCORIS](https:/
Everything else should work the same including layer targeting.
## Updates
### June 10, 2024
- Decided to keep track up updates in the readme
- Added support for SDXL in the UI
- Added support for SD 1.5 in the UI
- Fixed UI Wan 2.1 14b name bug
- Added support for for conv training in the UI for models that support it

View File

@@ -33,7 +33,6 @@ export default function SimpleJob({
gpuList,
datasetOptions,
}: Props) {
const modelArch = useMemo(() => {
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
}, [jobConfig.config.process[0].model.arch]);
@@ -104,8 +103,7 @@ export default function SimpleJob({
const newDataset = objectCopy(dataset);
newDataset.controls = controls;
return newDataset;
}
);
});
setJobConfig(datasets, 'config.process[0].datasets');
}}
options={
@@ -131,20 +129,22 @@ export default function SimpleJob({
placeholder=""
required
/>
<FormGroup label="Quantize">
<div className="grid grid-cols-2 gap-2">
<Checkbox
label="Transformer"
checked={jobConfig.config.process[0].model.quantize}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
/>
<Checkbox
label="Text Encoder"
checked={jobConfig.config.process[0].model.quantize_te}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
/>
</div>
</FormGroup>
{modelArch?.disableSections?.includes('model.quantize') ? null : (
<FormGroup label="Quantize">
<div className="grid grid-cols-2 gap-2">
<Checkbox
label="Transformer"
checked={jobConfig.config.process[0].model.quantize}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
/>
<Checkbox
label="Text Encoder"
checked={jobConfig.config.process[0].model.quantize_te}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
/>
</div>
</FormGroup>
)}
</Card>
<Card title="Target Configuration">
<SelectInput
@@ -171,19 +171,37 @@ export default function SimpleJob({
/>
)}
{jobConfig.config.process[0].network?.type == 'lora' && (
<NumberInput
label="Linear Rank"
value={jobConfig.config.process[0].network.linear}
onChange={value => {
console.log('onChange', value);
setJobConfig(value, 'config.process[0].network.linear');
setJobConfig(value, 'config.process[0].network.linear_alpha');
}}
placeholder="eg. 16"
min={0}
max={1024}
required
/>
<>
<NumberInput
label="Linear Rank"
value={jobConfig.config.process[0].network.linear}
onChange={value => {
console.log('onChange', value);
setJobConfig(value, 'config.process[0].network.linear');
setJobConfig(value, 'config.process[0].network.linear_alpha');
}}
placeholder="eg. 16"
min={0}
max={1024}
required
/>
{
modelArch?.disableSections?.includes('network.conv') ? null : (
<NumberInput
label="Conv Rank"
value={jobConfig.config.process[0].network.conv}
onChange={value => {
console.log('onChange', value);
setJobConfig(value, 'config.process[0].network.conv');
setJobConfig(value, 'config.process[0].network.conv_alpha');
}}
placeholder="eg. 16"
min={0}
max={1024}
/>
)
}
</>
)}
</Card>
<Card title="Save Configuration">
@@ -276,16 +294,19 @@ export default function SimpleJob({
/>
</div>
<div>
<SelectInput
label="Timestep Type"
value={jobConfig.config.process[0].train.timestep_type}
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
options={[
{ value: 'sigmoid', label: 'Sigmoid' },
{ value: 'linear', label: 'Linear' },
{ value: 'shift', label: 'Shift' },
]}
/>
{modelArch?.disableSections?.includes('train.timestep_type') ? null : (
<SelectInput
label="Timestep Type"
value={jobConfig.config.process[0].train.timestep_type}
disabled={modelArch?.disableSections?.includes('train.timestep_type') || false}
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
options={[
{ value: 'sigmoid', label: 'Sigmoid' },
{ value: 'linear', label: 'Linear' },
{ value: 'shift', label: 'Shift' },
]}
/>
)}
<SelectInput
label="Timestep Bias"
className="pt-2"
@@ -466,10 +487,7 @@ export default function SimpleJob({
// automaticallt add the controls for a new dataset
const controls = modelArch?.controls ?? [];
newDataset.controls = controls;
setJobConfig(
[...jobConfig.config.process[0].datasets, newDataset],
'config.process[0].datasets',
)
setJobConfig([...jobConfig.config.process[0].datasets, newDataset], 'config.process[0].datasets');
}}
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
>

View File

@@ -30,6 +30,8 @@ export const defaultJobConfig: JobConfig = {
type: 'lora',
linear: 32,
linear_alpha: 32,
conv: 16,
conv_alpha: 16,
lokr_full_rank: true,
lokr_factor: -1,
network_kwargs: {

View File

@@ -1,4 +1,3 @@
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
export interface ModelArch {
@@ -6,11 +5,14 @@ export interface ModelArch {
label: string;
controls?: Control[];
isVideoModel?: boolean;
defaults?: { [key: string]: [any, any] };
defaults?: { [key: string]: any };
disableSections?: DisableableSections[];
}
const defaultNameOrPath = '';
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
export const modelArchs: ModelArch[] = [
{
name: 'flux',
@@ -23,6 +25,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
disableSections: ['network.conv'],
},
{
name: 'flex1',
@@ -36,6 +39,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
disableSections: ['network.conv'],
},
{
name: 'flex2',
@@ -62,6 +66,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
disableSections: ['network.conv'],
},
{
name: 'chroma',
@@ -74,6 +79,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
disableSections: ['network.conv'],
},
{
name: 'wan21:1b',
@@ -89,6 +95,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.num_frames': [40, 1],
'config.process[0].sample.fps': [15, 1],
},
disableSections: ['network.conv'],
},
{
name: 'wan21:14b',
@@ -104,6 +111,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.num_frames': [40, 1],
'config.process[0].sample.fps': [15, 1],
},
disableSections: ['network.conv'],
},
{
name: 'lumina2',
@@ -116,6 +124,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
disableSections: ['network.conv'],
},
{
name: 'hidream',
@@ -131,5 +140,37 @@ export const modelArchs: ModelArch[] = [
'config.process[0].train.timestep_type': ['shift', 'sigmoid'],
'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []],
},
disableSections: ['network.conv'],
},
];
{
name: 'sdxl',
label: 'SDXL',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath],
'config.process[0].model.quantize': [false, false],
'config.process[0].model.quantize_te': [false, false],
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
'config.process[0].sample.guidance_scale': [6, 4],
},
disableSections: ['model.quantize', 'train.timestep_type'],
},
{
name: 'sd15',
label: 'SD 1.5',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath],
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
'config.process[0].sample.width': [512, 1024],
'config.process[0].sample.height': [512, 1024],
'config.process[0].sample.guidance_scale': [6, 4],
},
disableSections: ['model.quantize', 'train.timestep_type'],
},
].sort((a, b) => {
// Sort by label, case-insensitive
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' })
}) as any;

View File

@@ -2,8 +2,8 @@
import React, { forwardRef } from 'react';
import classNames from 'classnames';
import dynamic from "next/dynamic";
const Select = dynamic(() => import("react-select"), { ssr: false });
import dynamic from 'next/dynamic';
const Select = dynamic(() => import('react-select'), { ssr: false });
const labelClasses = 'block text-xs mb-1 mt-2 text-gray-300';
const inputClasses =
@@ -42,7 +42,7 @@ export const TextInput = forwardRef<HTMLInputElement, TextInputProps>(
/>
</div>
);
}
},
);
// 👇 Helpful for debugging
@@ -114,6 +114,7 @@ export const NumberInput = (props: NumberInputProps) => {
export interface SelectInputProps extends InputProps {
value: string;
disabled?: boolean;
onChange: (value: string) => void;
options: { value: string; label: string }[];
}
@@ -122,11 +123,16 @@ export const SelectInput = (props: SelectInputProps) => {
const { label, value, onChange, options } = props;
const selectedOption = options.find(option => option.value === value);
return (
<div className={classNames(props.className)}>
<div
className={classNames(props.className, {
'opacity-30 cursor-not-allowed': props.disabled,
})}
>
{label && <label className={labelClasses}>{label}</label>}
<Select
value={selectedOption}
<Select
value={selectedOption}
options={options}
isDisabled={props.disabled}
className="aitk-react-select-container"
classNamePrefix="aitk-react-select"
onChange={selected => {

View File

@@ -53,6 +53,8 @@ export interface NetworkConfig {
type: string;
linear: number;
linear_alpha: number;
conv: number;
conv_alpha: number;
lokr_full_rank: boolean;
lokr_factor: number;
network_kwargs: {