Added support for sdxl and sd1.5 to the ui.

This commit is contained in:
Jaret Burkett
2025-06-10 10:03:54 -06:00
parent d5c547da43
commit f8fb3b9c45
6 changed files with 131 additions and 53 deletions

View File

@@ -414,3 +414,12 @@ To learn more about LoKr, read more about it at [KohakuBlueleaf/LyCORIS](https:/
Everything else should work the same including layer targeting. Everything else should work the same including layer targeting.
## Updates
### June 10, 2024
- Decided to keep track up updates in the readme
- Added support for SDXL in the UI
- Added support for SD 1.5 in the UI
- Fixed UI Wan 2.1 14b name bug
- Added support for for conv training in the UI for models that support it

View File

@@ -33,7 +33,6 @@ export default function SimpleJob({
gpuList, gpuList,
datasetOptions, datasetOptions,
}: Props) { }: Props) {
const modelArch = useMemo(() => { const modelArch = useMemo(() => {
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch; return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
}, [jobConfig.config.process[0].model.arch]); }, [jobConfig.config.process[0].model.arch]);
@@ -104,8 +103,7 @@ export default function SimpleJob({
const newDataset = objectCopy(dataset); const newDataset = objectCopy(dataset);
newDataset.controls = controls; newDataset.controls = controls;
return newDataset; return newDataset;
} });
);
setJobConfig(datasets, 'config.process[0].datasets'); setJobConfig(datasets, 'config.process[0].datasets');
}} }}
options={ options={
@@ -131,20 +129,22 @@ export default function SimpleJob({
placeholder="" placeholder=""
required required
/> />
<FormGroup label="Quantize"> {modelArch?.disableSections?.includes('model.quantize') ? null : (
<div className="grid grid-cols-2 gap-2"> <FormGroup label="Quantize">
<Checkbox <div className="grid grid-cols-2 gap-2">
label="Transformer" <Checkbox
checked={jobConfig.config.process[0].model.quantize} label="Transformer"
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')} checked={jobConfig.config.process[0].model.quantize}
/> onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
<Checkbox />
label="Text Encoder" <Checkbox
checked={jobConfig.config.process[0].model.quantize_te} label="Text Encoder"
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')} checked={jobConfig.config.process[0].model.quantize_te}
/> onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
</div> />
</FormGroup> </div>
</FormGroup>
)}
</Card> </Card>
<Card title="Target Configuration"> <Card title="Target Configuration">
<SelectInput <SelectInput
@@ -171,19 +171,37 @@ export default function SimpleJob({
/> />
)} )}
{jobConfig.config.process[0].network?.type == 'lora' && ( {jobConfig.config.process[0].network?.type == 'lora' && (
<NumberInput <>
label="Linear Rank" <NumberInput
value={jobConfig.config.process[0].network.linear} label="Linear Rank"
onChange={value => { value={jobConfig.config.process[0].network.linear}
console.log('onChange', value); onChange={value => {
setJobConfig(value, 'config.process[0].network.linear'); console.log('onChange', value);
setJobConfig(value, 'config.process[0].network.linear_alpha'); setJobConfig(value, 'config.process[0].network.linear');
}} setJobConfig(value, 'config.process[0].network.linear_alpha');
placeholder="eg. 16" }}
min={0} placeholder="eg. 16"
max={1024} min={0}
required max={1024}
/> required
/>
{
modelArch?.disableSections?.includes('network.conv') ? null : (
<NumberInput
label="Conv Rank"
value={jobConfig.config.process[0].network.conv}
onChange={value => {
console.log('onChange', value);
setJobConfig(value, 'config.process[0].network.conv');
setJobConfig(value, 'config.process[0].network.conv_alpha');
}}
placeholder="eg. 16"
min={0}
max={1024}
/>
)
}
</>
)} )}
</Card> </Card>
<Card title="Save Configuration"> <Card title="Save Configuration">
@@ -276,16 +294,19 @@ export default function SimpleJob({
/> />
</div> </div>
<div> <div>
<SelectInput {modelArch?.disableSections?.includes('train.timestep_type') ? null : (
label="Timestep Type" <SelectInput
value={jobConfig.config.process[0].train.timestep_type} label="Timestep Type"
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')} value={jobConfig.config.process[0].train.timestep_type}
options={[ disabled={modelArch?.disableSections?.includes('train.timestep_type') || false}
{ value: 'sigmoid', label: 'Sigmoid' }, onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
{ value: 'linear', label: 'Linear' }, options={[
{ value: 'shift', label: 'Shift' }, { value: 'sigmoid', label: 'Sigmoid' },
]} { value: 'linear', label: 'Linear' },
/> { value: 'shift', label: 'Shift' },
]}
/>
)}
<SelectInput <SelectInput
label="Timestep Bias" label="Timestep Bias"
className="pt-2" className="pt-2"
@@ -466,10 +487,7 @@ export default function SimpleJob({
// automaticallt add the controls for a new dataset // automaticallt add the controls for a new dataset
const controls = modelArch?.controls ?? []; const controls = modelArch?.controls ?? [];
newDataset.controls = controls; newDataset.controls = controls;
setJobConfig( setJobConfig([...jobConfig.config.process[0].datasets, newDataset], 'config.process[0].datasets');
[...jobConfig.config.process[0].datasets, newDataset],
'config.process[0].datasets',
)
}} }}
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors" className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
> >

View File

@@ -30,6 +30,8 @@ export const defaultJobConfig: JobConfig = {
type: 'lora', type: 'lora',
linear: 32, linear: 32,
linear_alpha: 32, linear_alpha: 32,
conv: 16,
conv_alpha: 16,
lokr_full_rank: true, lokr_full_rank: true,
lokr_factor: -1, lokr_factor: -1,
network_kwargs: { network_kwargs: {

View File

@@ -1,4 +1,3 @@
type Control = 'depth' | 'line' | 'pose' | 'inpaint'; type Control = 'depth' | 'line' | 'pose' | 'inpaint';
export interface ModelArch { export interface ModelArch {
@@ -6,11 +5,14 @@ export interface ModelArch {
label: string; label: string;
controls?: Control[]; controls?: Control[];
isVideoModel?: boolean; isVideoModel?: boolean;
defaults?: { [key: string]: [any, any] }; defaults?: { [key: string]: any };
disableSections?: DisableableSections[];
} }
const defaultNameOrPath = ''; const defaultNameOrPath = '';
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
export const modelArchs: ModelArch[] = [ export const modelArchs: ModelArch[] = [
{ {
name: 'flux', name: 'flux',
@@ -23,6 +25,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
}, },
disableSections: ['network.conv'],
}, },
{ {
name: 'flex1', name: 'flex1',
@@ -36,6 +39,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
}, },
disableSections: ['network.conv'],
}, },
{ {
name: 'flex2', name: 'flex2',
@@ -62,6 +66,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
}, },
disableSections: ['network.conv'],
}, },
{ {
name: 'chroma', name: 'chroma',
@@ -74,6 +79,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
}, },
disableSections: ['network.conv'],
}, },
{ {
name: 'wan21:1b', name: 'wan21:1b',
@@ -89,6 +95,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.num_frames': [40, 1], 'config.process[0].sample.num_frames': [40, 1],
'config.process[0].sample.fps': [15, 1], 'config.process[0].sample.fps': [15, 1],
}, },
disableSections: ['network.conv'],
}, },
{ {
name: 'wan21:14b', name: 'wan21:14b',
@@ -104,6 +111,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.num_frames': [40, 1], 'config.process[0].sample.num_frames': [40, 1],
'config.process[0].sample.fps': [15, 1], 'config.process[0].sample.fps': [15, 1],
}, },
disableSections: ['network.conv'],
}, },
{ {
name: 'lumina2', name: 'lumina2',
@@ -116,6 +124,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
}, },
disableSections: ['network.conv'],
}, },
{ {
name: 'hidream', name: 'hidream',
@@ -131,5 +140,37 @@ export const modelArchs: ModelArch[] = [
'config.process[0].train.timestep_type': ['shift', 'sigmoid'], 'config.process[0].train.timestep_type': ['shift', 'sigmoid'],
'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []], 'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []],
}, },
disableSections: ['network.conv'],
}, },
]; {
name: 'sdxl',
label: 'SDXL',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath],
'config.process[0].model.quantize': [false, false],
'config.process[0].model.quantize_te': [false, false],
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
'config.process[0].sample.guidance_scale': [6, 4],
},
disableSections: ['model.quantize', 'train.timestep_type'],
},
{
name: 'sd15',
label: 'SD 1.5',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath],
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
'config.process[0].sample.width': [512, 1024],
'config.process[0].sample.height': [512, 1024],
'config.process[0].sample.guidance_scale': [6, 4],
},
disableSections: ['model.quantize', 'train.timestep_type'],
},
].sort((a, b) => {
// Sort by label, case-insensitive
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' })
}) as any;

View File

@@ -2,8 +2,8 @@
import React, { forwardRef } from 'react'; import React, { forwardRef } from 'react';
import classNames from 'classnames'; import classNames from 'classnames';
import dynamic from "next/dynamic"; import dynamic from 'next/dynamic';
const Select = dynamic(() => import("react-select"), { ssr: false }); const Select = dynamic(() => import('react-select'), { ssr: false });
const labelClasses = 'block text-xs mb-1 mt-2 text-gray-300'; const labelClasses = 'block text-xs mb-1 mt-2 text-gray-300';
const inputClasses = const inputClasses =
@@ -42,7 +42,7 @@ export const TextInput = forwardRef<HTMLInputElement, TextInputProps>(
/> />
</div> </div>
); );
} },
); );
// 👇 Helpful for debugging // 👇 Helpful for debugging
@@ -114,6 +114,7 @@ export const NumberInput = (props: NumberInputProps) => {
export interface SelectInputProps extends InputProps { export interface SelectInputProps extends InputProps {
value: string; value: string;
disabled?: boolean;
onChange: (value: string) => void; onChange: (value: string) => void;
options: { value: string; label: string }[]; options: { value: string; label: string }[];
} }
@@ -122,11 +123,16 @@ export const SelectInput = (props: SelectInputProps) => {
const { label, value, onChange, options } = props; const { label, value, onChange, options } = props;
const selectedOption = options.find(option => option.value === value); const selectedOption = options.find(option => option.value === value);
return ( return (
<div className={classNames(props.className)}> <div
className={classNames(props.className, {
'opacity-30 cursor-not-allowed': props.disabled,
})}
>
{label && <label className={labelClasses}>{label}</label>} {label && <label className={labelClasses}>{label}</label>}
<Select <Select
value={selectedOption} value={selectedOption}
options={options} options={options}
isDisabled={props.disabled}
className="aitk-react-select-container" className="aitk-react-select-container"
classNamePrefix="aitk-react-select" classNamePrefix="aitk-react-select"
onChange={selected => { onChange={selected => {

View File

@@ -53,6 +53,8 @@ export interface NetworkConfig {
type: string; type: string;
linear: number; linear: number;
linear_alpha: number; linear_alpha: number;
conv: number;
conv_alpha: number;
lokr_full_rank: boolean; lokr_full_rank: boolean;
lokr_factor: number; lokr_factor: number;
network_kwargs: { network_kwargs: {