mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Add vram flag to some models in the ui
This commit is contained in:
@@ -86,13 +86,19 @@ export default function SimpleJob({
|
|||||||
if (!currentArch || currentArch.name === value) {
|
if (!currentArch || currentArch.name === value) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
// update the defaults when a model is selected
|
||||||
|
const newArch = modelArchs.find(model => model.name === value);
|
||||||
|
|
||||||
|
// update vram setting
|
||||||
|
if (!(newArch?.additionalSections?.includes('model.low_vram'))) {
|
||||||
|
setJobConfig(false, 'config.process[0].model.low_vram');
|
||||||
|
}
|
||||||
|
|
||||||
// revert defaults from previous model
|
// revert defaults from previous model
|
||||||
for (const key in currentArch.defaults) {
|
for (const key in currentArch.defaults) {
|
||||||
setJobConfig(currentArch.defaults[key][1], key);
|
setJobConfig(currentArch.defaults[key][1], key);
|
||||||
}
|
}
|
||||||
// update the defaults when a model is selected
|
|
||||||
const newArch = modelArchs.find(model => model.name === value);
|
|
||||||
if (newArch?.defaults) {
|
if (newArch?.defaults) {
|
||||||
for (const key in newArch.defaults) {
|
for (const key in newArch.defaults) {
|
||||||
setJobConfig(newArch.defaults[key][0], key);
|
setJobConfig(newArch.defaults[key][0], key);
|
||||||
@@ -160,6 +166,15 @@ export default function SimpleJob({
|
|||||||
</div>
|
</div>
|
||||||
</FormGroup>
|
</FormGroup>
|
||||||
)}
|
)}
|
||||||
|
{modelArch?.additionalSections?.includes('model.low_vram') && (
|
||||||
|
<FormGroup label="Options">
|
||||||
|
<Checkbox
|
||||||
|
label="Low VRAM"
|
||||||
|
checked={jobConfig.config.process[0].model.low_vram}
|
||||||
|
onChange={value => setJobConfig(value, 'config.process[0].model.low_vram')}
|
||||||
|
/>
|
||||||
|
</FormGroup>
|
||||||
|
)}
|
||||||
</Card>
|
</Card>
|
||||||
<Card title="Target Configuration">
|
<Card title="Target Configuration">
|
||||||
<SelectInput
|
<SelectInput
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import { GroupedSelectOption } from "@/types";
|
|||||||
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
||||||
|
|
||||||
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
|
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
|
||||||
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' | 'datasets.num_frames';
|
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
|
||||||
type ModelGroup = 'image' | 'video';
|
type ModelGroup = 'image' | 'video';
|
||||||
|
|
||||||
export interface ModelArch {
|
export interface ModelArch {
|
||||||
@@ -126,7 +126,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].sample.fps': [15, 1],
|
'config.process[0].sample.fps': [15, 1],
|
||||||
},
|
},
|
||||||
disableSections: ['network.conv'],
|
disableSections: ['network.conv'],
|
||||||
additionalSections: ['datasets.num_frames'],
|
additionalSections: ['datasets.num_frames', 'model.low_vram'],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'wan21_i2v:14b480p',
|
name: 'wan21_i2v:14b480p',
|
||||||
@@ -145,7 +145,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
||||||
},
|
},
|
||||||
disableSections: ['network.conv'],
|
disableSections: ['network.conv'],
|
||||||
additionalSections: ['sample.ctrl_img', 'datasets.num_frames'],
|
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'wan21_i2v:14b',
|
name: 'wan21_i2v:14b',
|
||||||
@@ -164,7 +164,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
||||||
},
|
},
|
||||||
disableSections: ['network.conv'],
|
disableSections: ['network.conv'],
|
||||||
additionalSections: ['sample.ctrl_img', 'datasets.num_frames'],
|
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'wan21:14b',
|
name: 'wan21:14b',
|
||||||
@@ -182,7 +182,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].sample.fps': [15, 1],
|
'config.process[0].sample.fps': [15, 1],
|
||||||
},
|
},
|
||||||
disableSections: ['network.conv'],
|
disableSections: ['network.conv'],
|
||||||
additionalSections: ['datasets.num_frames'],
|
additionalSections: ['datasets.num_frames', 'model.low_vram'],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'lumina2',
|
name: 'lumina2',
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
VERSION = "0.3.9"
|
VERSION = "0.3.10"
|
||||||
Reference in New Issue
Block a user