diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 2c118f1c..07b3e2b7 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -137,7 +137,7 @@ class NetworkConfig: self.transformer_only = kwargs.get('transformer_only', True) self.lokr_full_rank = kwargs.get('lokr_full_rank', False) - if self.lokr_full_rank: + if self.lokr_full_rank and self.type.lower() == 'lokr': self.linear = 9999999999 self.linear_alpha = 9999999999 self.conv = 9999999999 diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts index d87d12c4..8098171d 100644 --- a/ui/src/app/jobs/new/jobConfig.ts +++ b/ui/src/app/jobs/new/jobConfig.ts @@ -29,6 +29,8 @@ export const defaultJobConfig: JobConfig = { type: 'lora', linear: 16, linear_alpha: 16, + lokr_full_rank: true, + lokr_factor: -1 }, save: { dtype: 'bf16', diff --git a/ui/src/app/jobs/new/page.tsx b/ui/src/app/jobs/new/page.tsx index e3aec77f..4652cf20 100644 --- a/ui/src/app/jobs/new/page.tsx +++ b/ui/src/app/jobs/new/page.tsx @@ -227,8 +227,31 @@ export default function TrainingForm() { - {jobConfig.config.process[0].network?.type && ( - + + setJobConfig(value, 'config.process[0].network.type')} + options={[ + { value: 'lora', label: 'LoRA' }, + { value: 'lokr', label: 'LoKr' }, + ]} + /> + {jobConfig.config.process[0].network?.type == 'lokr' && ( + setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')} + options={[ + { value: '-1', label: 'Auto' }, + { value: '4', label: '4' }, + { value: '8', label: '8' }, + { value: '16', label: '16' }, + { value: '32', label: '32' }, + ]} + /> + )} + {jobConfig.config.process[0].network?.type == 'lora' && ( - - )} + )} + setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')} + onChange={value => + setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier') + } placeholder="eg. 1.0" min={0} /> diff --git a/ui/src/types.ts b/ui/src/types.ts index 16ecb220..0aa7b67b 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -50,9 +50,11 @@ export interface GPUApiResponse { */ export interface NetworkConfig { - type: 'lora'; + type: string; linear: number; linear_alpha: number; + lokr_full_rank: boolean; + lokr_factor: number; } export interface SaveConfig {