mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added Differential Guidance training target
This commit is contained in:
@@ -708,7 +708,11 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
unconditional_target = unconditional_target * alpha
|
unconditional_target = unconditional_target * alpha
|
||||||
target = unconditional_target + guidance_scale * (target - unconditional_target)
|
target = unconditional_target + guidance_scale * (target - unconditional_target)
|
||||||
|
|
||||||
|
if self.train_config.do_differential_guidance:
|
||||||
|
with torch.no_grad():
|
||||||
|
guidance_scale = self.train_config.differential_guidance_scale
|
||||||
|
target = noise_pred + guidance_scale * (target - noise_pred)
|
||||||
|
|
||||||
if target is None:
|
if target is None:
|
||||||
target = noise
|
target = noise
|
||||||
|
|||||||
@@ -545,7 +545,10 @@ class TrainConfig:
|
|||||||
self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
|
self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
|
||||||
if isinstance(self.guidance_loss_target, tuple):
|
if isinstance(self.guidance_loss_target, tuple):
|
||||||
self.guidance_loss_target = list(self.guidance_loss_target)
|
self.guidance_loss_target = list(self.guidance_loss_target)
|
||||||
|
|
||||||
|
self.do_differential_guidance = kwargs.get('do_differential_guidance', False)
|
||||||
|
self.differential_guidance_scale = kwargs.get('differential_guidance_scale', 3.0)
|
||||||
|
|
||||||
# for multi stage models, how often to switch the boundary
|
# for multi stage models, how often to switch the boundary
|
||||||
self.switch_boundary_every: int = kwargs.get('switch_boundary_every', 1)
|
self.switch_boundary_every: int = kwargs.get('switch_boundary_every', 1)
|
||||||
|
|
||||||
|
|||||||
@@ -2907,7 +2907,7 @@ class StableDiffusion:
|
|||||||
try:
|
try:
|
||||||
te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad
|
te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad
|
||||||
except:
|
except:
|
||||||
te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
|
te_has_grad = False
|
||||||
self.device_state['text_encoder'].append({
|
self.device_state['text_encoder'].append({
|
||||||
'training': encoder.training,
|
'training': encoder.training,
|
||||||
'device': encoder.device,
|
'device': encoder.device,
|
||||||
|
|||||||
BIN
ui/public/imgs/diff_guidance.png
Normal file
BIN
ui/public/imgs/diff_guidance.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 45 KiB |
@@ -669,6 +669,45 @@ export default function SimpleJob({
|
|||||||
</div>
|
</div>
|
||||||
</Card>
|
</Card>
|
||||||
</div>
|
</div>
|
||||||
|
<div>
|
||||||
|
<Card title="Advanced" collapsible>
|
||||||
|
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||||
|
<div>
|
||||||
|
<Checkbox
|
||||||
|
label="Do Differential Guidance"
|
||||||
|
docKey={'train.do_differential_guidance'}
|
||||||
|
className="pt-1"
|
||||||
|
checked={jobConfig.config.process[0].train.do_differential_guidance || false}
|
||||||
|
onChange={value => {
|
||||||
|
let newValue = value == false ? undefined : value;
|
||||||
|
setJobConfig(newValue, 'config.process[0].train.do_differential_guidance');
|
||||||
|
if (!newValue) {
|
||||||
|
setJobConfig(undefined, 'config.process[0].train.differential_guidance_scale');
|
||||||
|
} else if (
|
||||||
|
jobConfig.config.process[0].train.differential_guidance_scale === undefined ||
|
||||||
|
jobConfig.config.process[0].train.differential_guidance_scale === null
|
||||||
|
) {
|
||||||
|
// set default differential guidance scale to 3.0
|
||||||
|
setJobConfig(3.0, 'config.process[0].train.differential_guidance_scale');
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
{jobConfig.config.process[0].train.differential_guidance_scale && (
|
||||||
|
<>
|
||||||
|
<NumberInput
|
||||||
|
label="Differential Guidance Scale"
|
||||||
|
className="pt-2"
|
||||||
|
value={(jobConfig.config.process[0].train.differential_guidance_scale as number) || 3.0}
|
||||||
|
onChange={value => setJobConfig(value, 'config.process[0].train.differential_guidance_scale')}
|
||||||
|
placeholder="eg. 3.0"
|
||||||
|
min={0}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Card>
|
||||||
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<Card title="Datasets">
|
<Card title="Datasets">
|
||||||
<>
|
<>
|
||||||
|
|||||||
@@ -1,13 +1,41 @@
|
|||||||
|
import { Disclosure, DisclosureButton, DisclosurePanel } from '@headlessui/react';
|
||||||
|
import { FaChevronDown } from 'react-icons/fa';
|
||||||
|
import classNames from 'classnames';
|
||||||
|
|
||||||
interface CardProps {
|
interface CardProps {
|
||||||
title?: string;
|
title?: string;
|
||||||
children?: React.ReactNode;
|
children?: React.ReactNode;
|
||||||
|
collapsible?: boolean;
|
||||||
|
defaultOpen?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
const Card: React.FC<CardProps> = ({ title, children }) => {
|
const Card: React.FC<CardProps> = ({ title, children, collapsible, defaultOpen }) => {
|
||||||
|
if (collapsible) {
|
||||||
|
return (
|
||||||
|
<Disclosure as="section" className="space-y-2 px-4 pb-2 pt-2 bg-gray-900 rounded-lg" defaultOpen={defaultOpen}>
|
||||||
|
{({ open }) => (
|
||||||
|
<>
|
||||||
|
<DisclosureButton className="w-full text-left flex items-center justify-between">
|
||||||
|
<div className="flex-1">
|
||||||
|
{title && (
|
||||||
|
<h2 className={classNames('text-lg mb-2 font-semibold uppercase text-gray-500', { 'mb-0': !open })}>
|
||||||
|
{title}
|
||||||
|
</h2>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<FaChevronDown className={`ml-2 inline-block transition-transform ${open ? 'rotate-180' : ''}`} />
|
||||||
|
</DisclosureButton>
|
||||||
|
<DisclosurePanel>{children ?? null}</DisclosurePanel>
|
||||||
|
{open && <div className="pt-2"></div>}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</Disclosure>
|
||||||
|
);
|
||||||
|
}
|
||||||
return (
|
return (
|
||||||
<section className="space-y-2 px-4 pb-4 pt-2 bg-gray-900 rounded-lg">
|
<section className="space-y-2 px-4 pb-4 pt-2 bg-gray-900 rounded-lg">
|
||||||
{title && <h2 className="text-lg mb-2 font-semibold uppercase text-gray-500">{title}</h2>}
|
{title && <h2 className="text-lg mb-2 font-semibold uppercase text-gray-500">{title}</h2>}
|
||||||
{children ? children : null}
|
{children ?? null}
|
||||||
</section>
|
</section>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -258,6 +258,25 @@ const docs: { [key: string]: ConfigDoc } = {
|
|||||||
</>
|
</>
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
'train.do_differential_guidance': {
|
||||||
|
title: 'Differential Guidance',
|
||||||
|
description: (
|
||||||
|
<>
|
||||||
|
Differential Guidance will amplify the difference of the model prediction and the target during training to make
|
||||||
|
a new target. Differential Guidance Scale will be the multiplier for the difference. This is still experimental,
|
||||||
|
but in my tests, it makes the model train faster, and learns details better in every scenario I have tried with
|
||||||
|
it.
|
||||||
|
<br />
|
||||||
|
<br />
|
||||||
|
The idea is that normal training inches closer to the target but never actually gets there, because it is
|
||||||
|
limited by the learning rate. With differential guidance, we amplify the difference for a new target beyond the
|
||||||
|
actual target, this would make the model learn to hit or overshoot the target instead of falling short.
|
||||||
|
<br />
|
||||||
|
<br />
|
||||||
|
<img src="/imgs/diff_guidance.png" alt="Differential Guidance Diagram" className="max-w-full mx-auto" />
|
||||||
|
</>
|
||||||
|
),
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
|
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
|
||||||
|
|||||||
@@ -139,6 +139,8 @@ export interface TrainConfig {
|
|||||||
blank_prompt_preservation_multiplier?: number;
|
blank_prompt_preservation_multiplier?: number;
|
||||||
switch_boundary_every: number;
|
switch_boundary_every: number;
|
||||||
loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped';
|
loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped';
|
||||||
|
do_differential_guidance?: boolean;
|
||||||
|
differential_guidance_scale?: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface QuantizeKwargsConfig {
|
export interface QuantizeKwargsConfig {
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
VERSION = "0.7.3"
|
VERSION = "0.7.4"
|
||||||
Reference in New Issue
Block a user