Added LoKr to the ui

This commit is contained in:
Jaret Burkett
2025-03-02 08:49:01 -07:00
parent b16819f8e7
commit 7ae31c9ae9
4 changed files with 36 additions and 7 deletions

View File

@@ -137,7 +137,7 @@ class NetworkConfig:
self.transformer_only = kwargs.get('transformer_only', True)
self.lokr_full_rank = kwargs.get('lokr_full_rank', False)
if self.lokr_full_rank:
if self.lokr_full_rank and self.type.lower() == 'lokr':
self.linear = 9999999999
self.linear_alpha = 9999999999
self.conv = 9999999999

View File

@@ -29,6 +29,8 @@ export const defaultJobConfig: JobConfig = {
type: 'lora',
linear: 16,
linear_alpha: 16,
lokr_full_rank: true,
lokr_factor: -1
},
save: {
dtype: 'bf16',

View File

@@ -227,8 +227,31 @@ export default function TrainingForm() {
</div>
</FormGroup>
</Card>
{jobConfig.config.process[0].network?.type && (
<Card title="LoRA Configuration">
<Card title="Target Configuration">
<SelectInput
label="Target Type"
value={jobConfig.config.process[0].network?.type ?? 'lora'}
onChange={value => setJobConfig(value, 'config.process[0].network.type')}
options={[
{ value: 'lora', label: 'LoRA' },
{ value: 'lokr', label: 'LoKr' },
]}
/>
{jobConfig.config.process[0].network?.type == 'lokr' && (
<SelectInput
label="LoKr Factor"
value={ `${jobConfig.config.process[0].network?.lokr_factor ?? -1}`}
onChange={value => setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')}
options={[
{ value: '-1', label: 'Auto' },
{ value: '4', label: '4' },
{ value: '8', label: '8' },
{ value: '16', label: '16' },
{ value: '32', label: '32' },
]}
/>
)}
{jobConfig.config.process[0].network?.type == 'lora' && (
<NumberInput
label="Linear Rank"
value={jobConfig.config.process[0].network.linear}
@@ -242,8 +265,8 @@ export default function TrainingForm() {
max={1024}
required
/>
</Card>
)}
)}
</Card>
<Card title="Save Configuration">
<SelectInput
label="Data Type"
@@ -397,7 +420,9 @@ export default function TrainingForm() {
label="DFE Loss Multiplier"
className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
onChange={value =>
setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
}
placeholder="eg. 1.0"
min={0}
/>

View File

@@ -50,9 +50,11 @@ export interface GPUApiResponse {
*/
export interface NetworkConfig {
type: 'lora';
type: string;
linear: number;
linear_alpha: number;
lokr_full_rank: boolean;
lokr_factor: number;
}
export interface SaveConfig {