mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added LoKr to the ui
This commit is contained in:
@@ -137,7 +137,7 @@ class NetworkConfig:
|
|||||||
self.transformer_only = kwargs.get('transformer_only', True)
|
self.transformer_only = kwargs.get('transformer_only', True)
|
||||||
|
|
||||||
self.lokr_full_rank = kwargs.get('lokr_full_rank', False)
|
self.lokr_full_rank = kwargs.get('lokr_full_rank', False)
|
||||||
if self.lokr_full_rank:
|
if self.lokr_full_rank and self.type.lower() == 'lokr':
|
||||||
self.linear = 9999999999
|
self.linear = 9999999999
|
||||||
self.linear_alpha = 9999999999
|
self.linear_alpha = 9999999999
|
||||||
self.conv = 9999999999
|
self.conv = 9999999999
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ export const defaultJobConfig: JobConfig = {
|
|||||||
type: 'lora',
|
type: 'lora',
|
||||||
linear: 16,
|
linear: 16,
|
||||||
linear_alpha: 16,
|
linear_alpha: 16,
|
||||||
|
lokr_full_rank: true,
|
||||||
|
lokr_factor: -1
|
||||||
},
|
},
|
||||||
save: {
|
save: {
|
||||||
dtype: 'bf16',
|
dtype: 'bf16',
|
||||||
|
|||||||
@@ -227,8 +227,31 @@ export default function TrainingForm() {
|
|||||||
</div>
|
</div>
|
||||||
</FormGroup>
|
</FormGroup>
|
||||||
</Card>
|
</Card>
|
||||||
{jobConfig.config.process[0].network?.type && (
|
<Card title="Target Configuration">
|
||||||
<Card title="LoRA Configuration">
|
<SelectInput
|
||||||
|
label="Target Type"
|
||||||
|
value={jobConfig.config.process[0].network?.type ?? 'lora'}
|
||||||
|
onChange={value => setJobConfig(value, 'config.process[0].network.type')}
|
||||||
|
options={[
|
||||||
|
{ value: 'lora', label: 'LoRA' },
|
||||||
|
{ value: 'lokr', label: 'LoKr' },
|
||||||
|
]}
|
||||||
|
/>
|
||||||
|
{jobConfig.config.process[0].network?.type == 'lokr' && (
|
||||||
|
<SelectInput
|
||||||
|
label="LoKr Factor"
|
||||||
|
value={ `${jobConfig.config.process[0].network?.lokr_factor ?? -1}`}
|
||||||
|
onChange={value => setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')}
|
||||||
|
options={[
|
||||||
|
{ value: '-1', label: 'Auto' },
|
||||||
|
{ value: '4', label: '4' },
|
||||||
|
{ value: '8', label: '8' },
|
||||||
|
{ value: '16', label: '16' },
|
||||||
|
{ value: '32', label: '32' },
|
||||||
|
]}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{jobConfig.config.process[0].network?.type == 'lora' && (
|
||||||
<NumberInput
|
<NumberInput
|
||||||
label="Linear Rank"
|
label="Linear Rank"
|
||||||
value={jobConfig.config.process[0].network.linear}
|
value={jobConfig.config.process[0].network.linear}
|
||||||
@@ -242,8 +265,8 @@ export default function TrainingForm() {
|
|||||||
max={1024}
|
max={1024}
|
||||||
required
|
required
|
||||||
/>
|
/>
|
||||||
</Card>
|
)}
|
||||||
)}
|
</Card>
|
||||||
<Card title="Save Configuration">
|
<Card title="Save Configuration">
|
||||||
<SelectInput
|
<SelectInput
|
||||||
label="Data Type"
|
label="Data Type"
|
||||||
@@ -397,7 +420,9 @@ export default function TrainingForm() {
|
|||||||
label="DFE Loss Multiplier"
|
label="DFE Loss Multiplier"
|
||||||
className="pt-2"
|
className="pt-2"
|
||||||
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
|
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
|
||||||
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
|
onChange={value =>
|
||||||
|
setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
|
||||||
|
}
|
||||||
placeholder="eg. 1.0"
|
placeholder="eg. 1.0"
|
||||||
min={0}
|
min={0}
|
||||||
/>
|
/>
|
||||||
|
|||||||
@@ -50,9 +50,11 @@ export interface GPUApiResponse {
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
export interface NetworkConfig {
|
export interface NetworkConfig {
|
||||||
type: 'lora';
|
type: string;
|
||||||
linear: number;
|
linear: number;
|
||||||
linear_alpha: number;
|
linear_alpha: number;
|
||||||
|
lokr_full_rank: boolean;
|
||||||
|
lokr_factor: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface SaveConfig {
|
export interface SaveConfig {
|
||||||
|
|||||||
Reference in New Issue
Block a user