diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 2c118f1c..07b3e2b7 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -137,7 +137,7 @@ class NetworkConfig:
self.transformer_only = kwargs.get('transformer_only', True)
self.lokr_full_rank = kwargs.get('lokr_full_rank', False)
- if self.lokr_full_rank:
+ if self.lokr_full_rank and self.type.lower() == 'lokr':
self.linear = 9999999999
self.linear_alpha = 9999999999
self.conv = 9999999999
diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts
index d87d12c4..8098171d 100644
--- a/ui/src/app/jobs/new/jobConfig.ts
+++ b/ui/src/app/jobs/new/jobConfig.ts
@@ -29,6 +29,8 @@ export const defaultJobConfig: JobConfig = {
type: 'lora',
linear: 16,
linear_alpha: 16,
+ lokr_full_rank: true,
+ lokr_factor: -1
},
save: {
dtype: 'bf16',
diff --git a/ui/src/app/jobs/new/page.tsx b/ui/src/app/jobs/new/page.tsx
index e3aec77f..4652cf20 100644
--- a/ui/src/app/jobs/new/page.tsx
+++ b/ui/src/app/jobs/new/page.tsx
@@ -227,8 +227,31 @@ export default function TrainingForm() {
- {jobConfig.config.process[0].network?.type && (
-
+
+ setJobConfig(value, 'config.process[0].network.type')}
+ options={[
+ { value: 'lora', label: 'LoRA' },
+ { value: 'lokr', label: 'LoKr' },
+ ]}
+ />
+ {jobConfig.config.process[0].network?.type == 'lokr' && (
+ setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')}
+ options={[
+ { value: '-1', label: 'Auto' },
+ { value: '4', label: '4' },
+ { value: '8', label: '8' },
+ { value: '16', label: '16' },
+ { value: '32', label: '32' },
+ ]}
+ />
+ )}
+ {jobConfig.config.process[0].network?.type == 'lora' && (
-
- )}
+ )}
+
setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
+ onChange={value =>
+ setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
+ }
placeholder="eg. 1.0"
min={0}
/>
diff --git a/ui/src/types.ts b/ui/src/types.ts
index 16ecb220..0aa7b67b 100644
--- a/ui/src/types.ts
+++ b/ui/src/types.ts
@@ -50,9 +50,11 @@ export interface GPUApiResponse {
*/
export interface NetworkConfig {
- type: 'lora';
+ type: string;
linear: number;
linear_alpha: number;
+ lokr_full_rank: boolean;
+ lokr_factor: number;
}
export interface SaveConfig {