mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added support for sdxl and sd1.5 to the ui.
This commit is contained in:
@@ -414,3 +414,12 @@ To learn more about LoKr, read more about it at [KohakuBlueleaf/LyCORIS](https:/
|
|||||||
|
|
||||||
Everything else should work the same including layer targeting.
|
Everything else should work the same including layer targeting.
|
||||||
|
|
||||||
|
|
||||||
|
## Updates
|
||||||
|
|
||||||
|
### June 10, 2024
|
||||||
|
- Decided to keep track up updates in the readme
|
||||||
|
- Added support for SDXL in the UI
|
||||||
|
- Added support for SD 1.5 in the UI
|
||||||
|
- Fixed UI Wan 2.1 14b name bug
|
||||||
|
- Added support for for conv training in the UI for models that support it
|
||||||
@@ -33,7 +33,6 @@ export default function SimpleJob({
|
|||||||
gpuList,
|
gpuList,
|
||||||
datasetOptions,
|
datasetOptions,
|
||||||
}: Props) {
|
}: Props) {
|
||||||
|
|
||||||
const modelArch = useMemo(() => {
|
const modelArch = useMemo(() => {
|
||||||
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
|
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
|
||||||
}, [jobConfig.config.process[0].model.arch]);
|
}, [jobConfig.config.process[0].model.arch]);
|
||||||
@@ -104,8 +103,7 @@ export default function SimpleJob({
|
|||||||
const newDataset = objectCopy(dataset);
|
const newDataset = objectCopy(dataset);
|
||||||
newDataset.controls = controls;
|
newDataset.controls = controls;
|
||||||
return newDataset;
|
return newDataset;
|
||||||
}
|
});
|
||||||
);
|
|
||||||
setJobConfig(datasets, 'config.process[0].datasets');
|
setJobConfig(datasets, 'config.process[0].datasets');
|
||||||
}}
|
}}
|
||||||
options={
|
options={
|
||||||
@@ -131,20 +129,22 @@ export default function SimpleJob({
|
|||||||
placeholder=""
|
placeholder=""
|
||||||
required
|
required
|
||||||
/>
|
/>
|
||||||
<FormGroup label="Quantize">
|
{modelArch?.disableSections?.includes('model.quantize') ? null : (
|
||||||
<div className="grid grid-cols-2 gap-2">
|
<FormGroup label="Quantize">
|
||||||
<Checkbox
|
<div className="grid grid-cols-2 gap-2">
|
||||||
label="Transformer"
|
<Checkbox
|
||||||
checked={jobConfig.config.process[0].model.quantize}
|
label="Transformer"
|
||||||
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
|
checked={jobConfig.config.process[0].model.quantize}
|
||||||
/>
|
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
|
||||||
<Checkbox
|
/>
|
||||||
label="Text Encoder"
|
<Checkbox
|
||||||
checked={jobConfig.config.process[0].model.quantize_te}
|
label="Text Encoder"
|
||||||
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
|
checked={jobConfig.config.process[0].model.quantize_te}
|
||||||
/>
|
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
|
||||||
</div>
|
/>
|
||||||
</FormGroup>
|
</div>
|
||||||
|
</FormGroup>
|
||||||
|
)}
|
||||||
</Card>
|
</Card>
|
||||||
<Card title="Target Configuration">
|
<Card title="Target Configuration">
|
||||||
<SelectInput
|
<SelectInput
|
||||||
@@ -171,19 +171,37 @@ export default function SimpleJob({
|
|||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{jobConfig.config.process[0].network?.type == 'lora' && (
|
{jobConfig.config.process[0].network?.type == 'lora' && (
|
||||||
<NumberInput
|
<>
|
||||||
label="Linear Rank"
|
<NumberInput
|
||||||
value={jobConfig.config.process[0].network.linear}
|
label="Linear Rank"
|
||||||
onChange={value => {
|
value={jobConfig.config.process[0].network.linear}
|
||||||
console.log('onChange', value);
|
onChange={value => {
|
||||||
setJobConfig(value, 'config.process[0].network.linear');
|
console.log('onChange', value);
|
||||||
setJobConfig(value, 'config.process[0].network.linear_alpha');
|
setJobConfig(value, 'config.process[0].network.linear');
|
||||||
}}
|
setJobConfig(value, 'config.process[0].network.linear_alpha');
|
||||||
placeholder="eg. 16"
|
}}
|
||||||
min={0}
|
placeholder="eg. 16"
|
||||||
max={1024}
|
min={0}
|
||||||
required
|
max={1024}
|
||||||
/>
|
required
|
||||||
|
/>
|
||||||
|
{
|
||||||
|
modelArch?.disableSections?.includes('network.conv') ? null : (
|
||||||
|
<NumberInput
|
||||||
|
label="Conv Rank"
|
||||||
|
value={jobConfig.config.process[0].network.conv}
|
||||||
|
onChange={value => {
|
||||||
|
console.log('onChange', value);
|
||||||
|
setJobConfig(value, 'config.process[0].network.conv');
|
||||||
|
setJobConfig(value, 'config.process[0].network.conv_alpha');
|
||||||
|
}}
|
||||||
|
placeholder="eg. 16"
|
||||||
|
min={0}
|
||||||
|
max={1024}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</>
|
||||||
)}
|
)}
|
||||||
</Card>
|
</Card>
|
||||||
<Card title="Save Configuration">
|
<Card title="Save Configuration">
|
||||||
@@ -276,16 +294,19 @@ export default function SimpleJob({
|
|||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<SelectInput
|
{modelArch?.disableSections?.includes('train.timestep_type') ? null : (
|
||||||
label="Timestep Type"
|
<SelectInput
|
||||||
value={jobConfig.config.process[0].train.timestep_type}
|
label="Timestep Type"
|
||||||
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
|
value={jobConfig.config.process[0].train.timestep_type}
|
||||||
options={[
|
disabled={modelArch?.disableSections?.includes('train.timestep_type') || false}
|
||||||
{ value: 'sigmoid', label: 'Sigmoid' },
|
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
|
||||||
{ value: 'linear', label: 'Linear' },
|
options={[
|
||||||
{ value: 'shift', label: 'Shift' },
|
{ value: 'sigmoid', label: 'Sigmoid' },
|
||||||
]}
|
{ value: 'linear', label: 'Linear' },
|
||||||
/>
|
{ value: 'shift', label: 'Shift' },
|
||||||
|
]}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
<SelectInput
|
<SelectInput
|
||||||
label="Timestep Bias"
|
label="Timestep Bias"
|
||||||
className="pt-2"
|
className="pt-2"
|
||||||
@@ -466,10 +487,7 @@ export default function SimpleJob({
|
|||||||
// automaticallt add the controls for a new dataset
|
// automaticallt add the controls for a new dataset
|
||||||
const controls = modelArch?.controls ?? [];
|
const controls = modelArch?.controls ?? [];
|
||||||
newDataset.controls = controls;
|
newDataset.controls = controls;
|
||||||
setJobConfig(
|
setJobConfig([...jobConfig.config.process[0].datasets, newDataset], 'config.process[0].datasets');
|
||||||
[...jobConfig.config.process[0].datasets, newDataset],
|
|
||||||
'config.process[0].datasets',
|
|
||||||
)
|
|
||||||
}}
|
}}
|
||||||
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
|
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
|
||||||
>
|
>
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ export const defaultJobConfig: JobConfig = {
|
|||||||
type: 'lora',
|
type: 'lora',
|
||||||
linear: 32,
|
linear: 32,
|
||||||
linear_alpha: 32,
|
linear_alpha: 32,
|
||||||
|
conv: 16,
|
||||||
|
conv_alpha: 16,
|
||||||
lokr_full_rank: true,
|
lokr_full_rank: true,
|
||||||
lokr_factor: -1,
|
lokr_factor: -1,
|
||||||
network_kwargs: {
|
network_kwargs: {
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
||||||
|
|
||||||
export interface ModelArch {
|
export interface ModelArch {
|
||||||
@@ -6,11 +5,14 @@ export interface ModelArch {
|
|||||||
label: string;
|
label: string;
|
||||||
controls?: Control[];
|
controls?: Control[];
|
||||||
isVideoModel?: boolean;
|
isVideoModel?: boolean;
|
||||||
defaults?: { [key: string]: [any, any] };
|
defaults?: { [key: string]: any };
|
||||||
|
disableSections?: DisableableSections[];
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultNameOrPath = '';
|
const defaultNameOrPath = '';
|
||||||
|
|
||||||
|
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
|
||||||
|
|
||||||
export const modelArchs: ModelArch[] = [
|
export const modelArchs: ModelArch[] = [
|
||||||
{
|
{
|
||||||
name: 'flux',
|
name: 'flux',
|
||||||
@@ -23,6 +25,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||||
},
|
},
|
||||||
|
disableSections: ['network.conv'],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'flex1',
|
name: 'flex1',
|
||||||
@@ -36,6 +39,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||||
},
|
},
|
||||||
|
disableSections: ['network.conv'],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'flex2',
|
name: 'flex2',
|
||||||
@@ -62,6 +66,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||||
},
|
},
|
||||||
|
disableSections: ['network.conv'],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'chroma',
|
name: 'chroma',
|
||||||
@@ -74,6 +79,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||||
},
|
},
|
||||||
|
disableSections: ['network.conv'],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'wan21:1b',
|
name: 'wan21:1b',
|
||||||
@@ -89,6 +95,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].sample.num_frames': [40, 1],
|
'config.process[0].sample.num_frames': [40, 1],
|
||||||
'config.process[0].sample.fps': [15, 1],
|
'config.process[0].sample.fps': [15, 1],
|
||||||
},
|
},
|
||||||
|
disableSections: ['network.conv'],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'wan21:14b',
|
name: 'wan21:14b',
|
||||||
@@ -104,6 +111,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].sample.num_frames': [40, 1],
|
'config.process[0].sample.num_frames': [40, 1],
|
||||||
'config.process[0].sample.fps': [15, 1],
|
'config.process[0].sample.fps': [15, 1],
|
||||||
},
|
},
|
||||||
|
disableSections: ['network.conv'],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'lumina2',
|
name: 'lumina2',
|
||||||
@@ -116,6 +124,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||||
},
|
},
|
||||||
|
disableSections: ['network.conv'],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'hidream',
|
name: 'hidream',
|
||||||
@@ -131,5 +140,37 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].train.timestep_type': ['shift', 'sigmoid'],
|
'config.process[0].train.timestep_type': ['shift', 'sigmoid'],
|
||||||
'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []],
|
'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []],
|
||||||
},
|
},
|
||||||
|
disableSections: ['network.conv'],
|
||||||
},
|
},
|
||||||
];
|
{
|
||||||
|
name: 'sdxl',
|
||||||
|
label: 'SDXL',
|
||||||
|
defaults: {
|
||||||
|
// default updates when [selected, unselected] in the UI
|
||||||
|
'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath],
|
||||||
|
'config.process[0].model.quantize': [false, false],
|
||||||
|
'config.process[0].model.quantize_te': [false, false],
|
||||||
|
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
|
||||||
|
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
|
||||||
|
'config.process[0].sample.guidance_scale': [6, 4],
|
||||||
|
},
|
||||||
|
disableSections: ['model.quantize', 'train.timestep_type'],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'sd15',
|
||||||
|
label: 'SD 1.5',
|
||||||
|
defaults: {
|
||||||
|
// default updates when [selected, unselected] in the UI
|
||||||
|
'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath],
|
||||||
|
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
|
||||||
|
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
|
||||||
|
'config.process[0].sample.width': [512, 1024],
|
||||||
|
'config.process[0].sample.height': [512, 1024],
|
||||||
|
'config.process[0].sample.guidance_scale': [6, 4],
|
||||||
|
},
|
||||||
|
disableSections: ['model.quantize', 'train.timestep_type'],
|
||||||
|
},
|
||||||
|
].sort((a, b) => {
|
||||||
|
// Sort by label, case-insensitive
|
||||||
|
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' })
|
||||||
|
}) as any;
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
import React, { forwardRef } from 'react';
|
import React, { forwardRef } from 'react';
|
||||||
import classNames from 'classnames';
|
import classNames from 'classnames';
|
||||||
import dynamic from "next/dynamic";
|
import dynamic from 'next/dynamic';
|
||||||
const Select = dynamic(() => import("react-select"), { ssr: false });
|
const Select = dynamic(() => import('react-select'), { ssr: false });
|
||||||
|
|
||||||
const labelClasses = 'block text-xs mb-1 mt-2 text-gray-300';
|
const labelClasses = 'block text-xs mb-1 mt-2 text-gray-300';
|
||||||
const inputClasses =
|
const inputClasses =
|
||||||
@@ -42,7 +42,7 @@ export const TextInput = forwardRef<HTMLInputElement, TextInputProps>(
|
|||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
// 👇 Helpful for debugging
|
// 👇 Helpful for debugging
|
||||||
@@ -114,6 +114,7 @@ export const NumberInput = (props: NumberInputProps) => {
|
|||||||
|
|
||||||
export interface SelectInputProps extends InputProps {
|
export interface SelectInputProps extends InputProps {
|
||||||
value: string;
|
value: string;
|
||||||
|
disabled?: boolean;
|
||||||
onChange: (value: string) => void;
|
onChange: (value: string) => void;
|
||||||
options: { value: string; label: string }[];
|
options: { value: string; label: string }[];
|
||||||
}
|
}
|
||||||
@@ -122,11 +123,16 @@ export const SelectInput = (props: SelectInputProps) => {
|
|||||||
const { label, value, onChange, options } = props;
|
const { label, value, onChange, options } = props;
|
||||||
const selectedOption = options.find(option => option.value === value);
|
const selectedOption = options.find(option => option.value === value);
|
||||||
return (
|
return (
|
||||||
<div className={classNames(props.className)}>
|
<div
|
||||||
|
className={classNames(props.className, {
|
||||||
|
'opacity-30 cursor-not-allowed': props.disabled,
|
||||||
|
})}
|
||||||
|
>
|
||||||
{label && <label className={labelClasses}>{label}</label>}
|
{label && <label className={labelClasses}>{label}</label>}
|
||||||
<Select
|
<Select
|
||||||
value={selectedOption}
|
value={selectedOption}
|
||||||
options={options}
|
options={options}
|
||||||
|
isDisabled={props.disabled}
|
||||||
className="aitk-react-select-container"
|
className="aitk-react-select-container"
|
||||||
classNamePrefix="aitk-react-select"
|
classNamePrefix="aitk-react-select"
|
||||||
onChange={selected => {
|
onChange={selected => {
|
||||||
|
|||||||
@@ -53,6 +53,8 @@ export interface NetworkConfig {
|
|||||||
type: string;
|
type: string;
|
||||||
linear: number;
|
linear: number;
|
||||||
linear_alpha: number;
|
linear_alpha: number;
|
||||||
|
conv: number;
|
||||||
|
conv_alpha: number;
|
||||||
lokr_full_rank: boolean;
|
lokr_full_rank: boolean;
|
||||||
lokr_factor: number;
|
lokr_factor: number;
|
||||||
network_kwargs: {
|
network_kwargs: {
|
||||||
|
|||||||
Reference in New Issue
Block a user