mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added support for sdxl and sd1.5 to the ui.
This commit is contained in:
@@ -414,3 +414,12 @@ To learn more about LoKr, read more about it at [KohakuBlueleaf/LyCORIS](https:/
|
||||
|
||||
Everything else should work the same including layer targeting.
|
||||
|
||||
|
||||
## Updates
|
||||
|
||||
### June 10, 2024
|
||||
- Decided to keep track up updates in the readme
|
||||
- Added support for SDXL in the UI
|
||||
- Added support for SD 1.5 in the UI
|
||||
- Fixed UI Wan 2.1 14b name bug
|
||||
- Added support for for conv training in the UI for models that support it
|
||||
@@ -33,7 +33,6 @@ export default function SimpleJob({
|
||||
gpuList,
|
||||
datasetOptions,
|
||||
}: Props) {
|
||||
|
||||
const modelArch = useMemo(() => {
|
||||
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
|
||||
}, [jobConfig.config.process[0].model.arch]);
|
||||
@@ -104,8 +103,7 @@ export default function SimpleJob({
|
||||
const newDataset = objectCopy(dataset);
|
||||
newDataset.controls = controls;
|
||||
return newDataset;
|
||||
}
|
||||
);
|
||||
});
|
||||
setJobConfig(datasets, 'config.process[0].datasets');
|
||||
}}
|
||||
options={
|
||||
@@ -131,20 +129,22 @@ export default function SimpleJob({
|
||||
placeholder=""
|
||||
required
|
||||
/>
|
||||
<FormGroup label="Quantize">
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
<Checkbox
|
||||
label="Transformer"
|
||||
checked={jobConfig.config.process[0].model.quantize}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
|
||||
/>
|
||||
<Checkbox
|
||||
label="Text Encoder"
|
||||
checked={jobConfig.config.process[0].model.quantize_te}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
|
||||
/>
|
||||
</div>
|
||||
</FormGroup>
|
||||
{modelArch?.disableSections?.includes('model.quantize') ? null : (
|
||||
<FormGroup label="Quantize">
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
<Checkbox
|
||||
label="Transformer"
|
||||
checked={jobConfig.config.process[0].model.quantize}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
|
||||
/>
|
||||
<Checkbox
|
||||
label="Text Encoder"
|
||||
checked={jobConfig.config.process[0].model.quantize_te}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
|
||||
/>
|
||||
</div>
|
||||
</FormGroup>
|
||||
)}
|
||||
</Card>
|
||||
<Card title="Target Configuration">
|
||||
<SelectInput
|
||||
@@ -171,19 +171,37 @@ export default function SimpleJob({
|
||||
/>
|
||||
)}
|
||||
{jobConfig.config.process[0].network?.type == 'lora' && (
|
||||
<NumberInput
|
||||
label="Linear Rank"
|
||||
value={jobConfig.config.process[0].network.linear}
|
||||
onChange={value => {
|
||||
console.log('onChange', value);
|
||||
setJobConfig(value, 'config.process[0].network.linear');
|
||||
setJobConfig(value, 'config.process[0].network.linear_alpha');
|
||||
}}
|
||||
placeholder="eg. 16"
|
||||
min={0}
|
||||
max={1024}
|
||||
required
|
||||
/>
|
||||
<>
|
||||
<NumberInput
|
||||
label="Linear Rank"
|
||||
value={jobConfig.config.process[0].network.linear}
|
||||
onChange={value => {
|
||||
console.log('onChange', value);
|
||||
setJobConfig(value, 'config.process[0].network.linear');
|
||||
setJobConfig(value, 'config.process[0].network.linear_alpha');
|
||||
}}
|
||||
placeholder="eg. 16"
|
||||
min={0}
|
||||
max={1024}
|
||||
required
|
||||
/>
|
||||
{
|
||||
modelArch?.disableSections?.includes('network.conv') ? null : (
|
||||
<NumberInput
|
||||
label="Conv Rank"
|
||||
value={jobConfig.config.process[0].network.conv}
|
||||
onChange={value => {
|
||||
console.log('onChange', value);
|
||||
setJobConfig(value, 'config.process[0].network.conv');
|
||||
setJobConfig(value, 'config.process[0].network.conv_alpha');
|
||||
}}
|
||||
placeholder="eg. 16"
|
||||
min={0}
|
||||
max={1024}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</>
|
||||
)}
|
||||
</Card>
|
||||
<Card title="Save Configuration">
|
||||
@@ -276,16 +294,19 @@ export default function SimpleJob({
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<SelectInput
|
||||
label="Timestep Type"
|
||||
value={jobConfig.config.process[0].train.timestep_type}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
|
||||
options={[
|
||||
{ value: 'sigmoid', label: 'Sigmoid' },
|
||||
{ value: 'linear', label: 'Linear' },
|
||||
{ value: 'shift', label: 'Shift' },
|
||||
]}
|
||||
/>
|
||||
{modelArch?.disableSections?.includes('train.timestep_type') ? null : (
|
||||
<SelectInput
|
||||
label="Timestep Type"
|
||||
value={jobConfig.config.process[0].train.timestep_type}
|
||||
disabled={modelArch?.disableSections?.includes('train.timestep_type') || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
|
||||
options={[
|
||||
{ value: 'sigmoid', label: 'Sigmoid' },
|
||||
{ value: 'linear', label: 'Linear' },
|
||||
{ value: 'shift', label: 'Shift' },
|
||||
]}
|
||||
/>
|
||||
)}
|
||||
<SelectInput
|
||||
label="Timestep Bias"
|
||||
className="pt-2"
|
||||
@@ -466,10 +487,7 @@ export default function SimpleJob({
|
||||
// automaticallt add the controls for a new dataset
|
||||
const controls = modelArch?.controls ?? [];
|
||||
newDataset.controls = controls;
|
||||
setJobConfig(
|
||||
[...jobConfig.config.process[0].datasets, newDataset],
|
||||
'config.process[0].datasets',
|
||||
)
|
||||
setJobConfig([...jobConfig.config.process[0].datasets, newDataset], 'config.process[0].datasets');
|
||||
}}
|
||||
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
|
||||
>
|
||||
|
||||
@@ -30,6 +30,8 @@ export const defaultJobConfig: JobConfig = {
|
||||
type: 'lora',
|
||||
linear: 32,
|
||||
linear_alpha: 32,
|
||||
conv: 16,
|
||||
conv_alpha: 16,
|
||||
lokr_full_rank: true,
|
||||
lokr_factor: -1,
|
||||
network_kwargs: {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
||||
|
||||
export interface ModelArch {
|
||||
@@ -6,11 +5,14 @@ export interface ModelArch {
|
||||
label: string;
|
||||
controls?: Control[];
|
||||
isVideoModel?: boolean;
|
||||
defaults?: { [key: string]: [any, any] };
|
||||
defaults?: { [key: string]: any };
|
||||
disableSections?: DisableableSections[];
|
||||
}
|
||||
|
||||
const defaultNameOrPath = '';
|
||||
|
||||
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
|
||||
|
||||
export const modelArchs: ModelArch[] = [
|
||||
{
|
||||
name: 'flux',
|
||||
@@ -23,6 +25,7 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
},
|
||||
{
|
||||
name: 'flex1',
|
||||
@@ -36,6 +39,7 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
},
|
||||
{
|
||||
name: 'flex2',
|
||||
@@ -62,6 +66,7 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
},
|
||||
{
|
||||
name: 'chroma',
|
||||
@@ -74,6 +79,7 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
},
|
||||
{
|
||||
name: 'wan21:1b',
|
||||
@@ -89,6 +95,7 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].sample.num_frames': [40, 1],
|
||||
'config.process[0].sample.fps': [15, 1],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
},
|
||||
{
|
||||
name: 'wan21:14b',
|
||||
@@ -104,6 +111,7 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].sample.num_frames': [40, 1],
|
||||
'config.process[0].sample.fps': [15, 1],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
},
|
||||
{
|
||||
name: 'lumina2',
|
||||
@@ -116,6 +124,7 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
},
|
||||
{
|
||||
name: 'hidream',
|
||||
@@ -131,5 +140,37 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].train.timestep_type': ['shift', 'sigmoid'],
|
||||
'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
},
|
||||
];
|
||||
{
|
||||
name: 'sdxl',
|
||||
label: 'SDXL',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath],
|
||||
'config.process[0].model.quantize': [false, false],
|
||||
'config.process[0].model.quantize_te': [false, false],
|
||||
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
|
||||
'config.process[0].sample.guidance_scale': [6, 4],
|
||||
},
|
||||
disableSections: ['model.quantize', 'train.timestep_type'],
|
||||
},
|
||||
{
|
||||
name: 'sd15',
|
||||
label: 'SD 1.5',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath],
|
||||
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
|
||||
'config.process[0].sample.width': [512, 1024],
|
||||
'config.process[0].sample.height': [512, 1024],
|
||||
'config.process[0].sample.guidance_scale': [6, 4],
|
||||
},
|
||||
disableSections: ['model.quantize', 'train.timestep_type'],
|
||||
},
|
||||
].sort((a, b) => {
|
||||
// Sort by label, case-insensitive
|
||||
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' })
|
||||
}) as any;
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
import React, { forwardRef } from 'react';
|
||||
import classNames from 'classnames';
|
||||
import dynamic from "next/dynamic";
|
||||
const Select = dynamic(() => import("react-select"), { ssr: false });
|
||||
import dynamic from 'next/dynamic';
|
||||
const Select = dynamic(() => import('react-select'), { ssr: false });
|
||||
|
||||
const labelClasses = 'block text-xs mb-1 mt-2 text-gray-300';
|
||||
const inputClasses =
|
||||
@@ -42,7 +42,7 @@ export const TextInput = forwardRef<HTMLInputElement, TextInputProps>(
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
// 👇 Helpful for debugging
|
||||
@@ -114,6 +114,7 @@ export const NumberInput = (props: NumberInputProps) => {
|
||||
|
||||
export interface SelectInputProps extends InputProps {
|
||||
value: string;
|
||||
disabled?: boolean;
|
||||
onChange: (value: string) => void;
|
||||
options: { value: string; label: string }[];
|
||||
}
|
||||
@@ -122,11 +123,16 @@ export const SelectInput = (props: SelectInputProps) => {
|
||||
const { label, value, onChange, options } = props;
|
||||
const selectedOption = options.find(option => option.value === value);
|
||||
return (
|
||||
<div className={classNames(props.className)}>
|
||||
<div
|
||||
className={classNames(props.className, {
|
||||
'opacity-30 cursor-not-allowed': props.disabled,
|
||||
})}
|
||||
>
|
||||
{label && <label className={labelClasses}>{label}</label>}
|
||||
<Select
|
||||
value={selectedOption}
|
||||
<Select
|
||||
value={selectedOption}
|
||||
options={options}
|
||||
isDisabled={props.disabled}
|
||||
className="aitk-react-select-container"
|
||||
classNamePrefix="aitk-react-select"
|
||||
onChange={selected => {
|
||||
|
||||
@@ -53,6 +53,8 @@ export interface NetworkConfig {
|
||||
type: string;
|
||||
linear: number;
|
||||
linear_alpha: number;
|
||||
conv: number;
|
||||
conv_alpha: number;
|
||||
lokr_full_rank: boolean;
|
||||
lokr_factor: number;
|
||||
network_kwargs: {
|
||||
|
||||
Reference in New Issue
Block a user