mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added LoKr to the ui
This commit is contained in:
@@ -137,7 +137,7 @@ class NetworkConfig:
|
||||
self.transformer_only = kwargs.get('transformer_only', True)
|
||||
|
||||
self.lokr_full_rank = kwargs.get('lokr_full_rank', False)
|
||||
if self.lokr_full_rank:
|
||||
if self.lokr_full_rank and self.type.lower() == 'lokr':
|
||||
self.linear = 9999999999
|
||||
self.linear_alpha = 9999999999
|
||||
self.conv = 9999999999
|
||||
|
||||
@@ -29,6 +29,8 @@ export const defaultJobConfig: JobConfig = {
|
||||
type: 'lora',
|
||||
linear: 16,
|
||||
linear_alpha: 16,
|
||||
lokr_full_rank: true,
|
||||
lokr_factor: -1
|
||||
},
|
||||
save: {
|
||||
dtype: 'bf16',
|
||||
|
||||
@@ -227,8 +227,31 @@ export default function TrainingForm() {
|
||||
</div>
|
||||
</FormGroup>
|
||||
</Card>
|
||||
{jobConfig.config.process[0].network?.type && (
|
||||
<Card title="LoRA Configuration">
|
||||
<Card title="Target Configuration">
|
||||
<SelectInput
|
||||
label="Target Type"
|
||||
value={jobConfig.config.process[0].network?.type ?? 'lora'}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].network.type')}
|
||||
options={[
|
||||
{ value: 'lora', label: 'LoRA' },
|
||||
{ value: 'lokr', label: 'LoKr' },
|
||||
]}
|
||||
/>
|
||||
{jobConfig.config.process[0].network?.type == 'lokr' && (
|
||||
<SelectInput
|
||||
label="LoKr Factor"
|
||||
value={ `${jobConfig.config.process[0].network?.lokr_factor ?? -1}`}
|
||||
onChange={value => setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')}
|
||||
options={[
|
||||
{ value: '-1', label: 'Auto' },
|
||||
{ value: '4', label: '4' },
|
||||
{ value: '8', label: '8' },
|
||||
{ value: '16', label: '16' },
|
||||
{ value: '32', label: '32' },
|
||||
]}
|
||||
/>
|
||||
)}
|
||||
{jobConfig.config.process[0].network?.type == 'lora' && (
|
||||
<NumberInput
|
||||
label="Linear Rank"
|
||||
value={jobConfig.config.process[0].network.linear}
|
||||
@@ -242,8 +265,8 @@ export default function TrainingForm() {
|
||||
max={1024}
|
||||
required
|
||||
/>
|
||||
</Card>
|
||||
)}
|
||||
)}
|
||||
</Card>
|
||||
<Card title="Save Configuration">
|
||||
<SelectInput
|
||||
label="Data Type"
|
||||
@@ -397,7 +420,9 @@ export default function TrainingForm() {
|
||||
label="DFE Loss Multiplier"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
|
||||
onChange={value =>
|
||||
setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
|
||||
}
|
||||
placeholder="eg. 1.0"
|
||||
min={0}
|
||||
/>
|
||||
|
||||
@@ -50,9 +50,11 @@ export interface GPUApiResponse {
|
||||
*/
|
||||
|
||||
export interface NetworkConfig {
|
||||
type: 'lora';
|
||||
type: string;
|
||||
linear: number;
|
||||
linear_alpha: number;
|
||||
lokr_full_rank: boolean;
|
||||
lokr_factor: number;
|
||||
}
|
||||
|
||||
export interface SaveConfig {
|
||||
|
||||
Reference in New Issue
Block a user