mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added Differential Guidance training target
This commit is contained in:
@@ -708,7 +708,11 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
unconditional_target = unconditional_target * alpha
|
||||
target = unconditional_target + guidance_scale * (target - unconditional_target)
|
||||
|
||||
|
||||
if self.train_config.do_differential_guidance:
|
||||
with torch.no_grad():
|
||||
guidance_scale = self.train_config.differential_guidance_scale
|
||||
target = noise_pred + guidance_scale * (target - noise_pred)
|
||||
|
||||
if target is None:
|
||||
target = noise
|
||||
|
||||
@@ -545,7 +545,10 @@ class TrainConfig:
|
||||
self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
|
||||
if isinstance(self.guidance_loss_target, tuple):
|
||||
self.guidance_loss_target = list(self.guidance_loss_target)
|
||||
|
||||
|
||||
self.do_differential_guidance = kwargs.get('do_differential_guidance', False)
|
||||
self.differential_guidance_scale = kwargs.get('differential_guidance_scale', 3.0)
|
||||
|
||||
# for multi stage models, how often to switch the boundary
|
||||
self.switch_boundary_every: int = kwargs.get('switch_boundary_every', 1)
|
||||
|
||||
|
||||
@@ -2907,7 +2907,7 @@ class StableDiffusion:
|
||||
try:
|
||||
te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad
|
||||
except:
|
||||
te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
|
||||
te_has_grad = False
|
||||
self.device_state['text_encoder'].append({
|
||||
'training': encoder.training,
|
||||
'device': encoder.device,
|
||||
|
||||
BIN
ui/public/imgs/diff_guidance.png
Normal file
BIN
ui/public/imgs/diff_guidance.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 45 KiB |
@@ -669,6 +669,45 @@ export default function SimpleJob({
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
<div>
|
||||
<Card title="Advanced" collapsible>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||
<div>
|
||||
<Checkbox
|
||||
label="Do Differential Guidance"
|
||||
docKey={'train.do_differential_guidance'}
|
||||
className="pt-1"
|
||||
checked={jobConfig.config.process[0].train.do_differential_guidance || false}
|
||||
onChange={value => {
|
||||
let newValue = value == false ? undefined : value;
|
||||
setJobConfig(newValue, 'config.process[0].train.do_differential_guidance');
|
||||
if (!newValue) {
|
||||
setJobConfig(undefined, 'config.process[0].train.differential_guidance_scale');
|
||||
} else if (
|
||||
jobConfig.config.process[0].train.differential_guidance_scale === undefined ||
|
||||
jobConfig.config.process[0].train.differential_guidance_scale === null
|
||||
) {
|
||||
// set default differential guidance scale to 3.0
|
||||
setJobConfig(3.0, 'config.process[0].train.differential_guidance_scale');
|
||||
}
|
||||
}}
|
||||
/>
|
||||
{jobConfig.config.process[0].train.differential_guidance_scale && (
|
||||
<>
|
||||
<NumberInput
|
||||
label="Differential Guidance Scale"
|
||||
className="pt-2"
|
||||
value={(jobConfig.config.process[0].train.differential_guidance_scale as number) || 3.0}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.differential_guidance_scale')}
|
||||
placeholder="eg. 3.0"
|
||||
min={0}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
<div>
|
||||
<Card title="Datasets">
|
||||
<>
|
||||
|
||||
@@ -1,13 +1,41 @@
|
||||
import { Disclosure, DisclosureButton, DisclosurePanel } from '@headlessui/react';
|
||||
import { FaChevronDown } from 'react-icons/fa';
|
||||
import classNames from 'classnames';
|
||||
|
||||
interface CardProps {
|
||||
title?: string;
|
||||
children?: React.ReactNode;
|
||||
collapsible?: boolean;
|
||||
defaultOpen?: boolean;
|
||||
}
|
||||
|
||||
const Card: React.FC<CardProps> = ({ title, children }) => {
|
||||
const Card: React.FC<CardProps> = ({ title, children, collapsible, defaultOpen }) => {
|
||||
if (collapsible) {
|
||||
return (
|
||||
<Disclosure as="section" className="space-y-2 px-4 pb-2 pt-2 bg-gray-900 rounded-lg" defaultOpen={defaultOpen}>
|
||||
{({ open }) => (
|
||||
<>
|
||||
<DisclosureButton className="w-full text-left flex items-center justify-between">
|
||||
<div className="flex-1">
|
||||
{title && (
|
||||
<h2 className={classNames('text-lg mb-2 font-semibold uppercase text-gray-500', { 'mb-0': !open })}>
|
||||
{title}
|
||||
</h2>
|
||||
)}
|
||||
</div>
|
||||
<FaChevronDown className={`ml-2 inline-block transition-transform ${open ? 'rotate-180' : ''}`} />
|
||||
</DisclosureButton>
|
||||
<DisclosurePanel>{children ?? null}</DisclosurePanel>
|
||||
{open && <div className="pt-2"></div>}
|
||||
</>
|
||||
)}
|
||||
</Disclosure>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<section className="space-y-2 px-4 pb-4 pt-2 bg-gray-900 rounded-lg">
|
||||
{title && <h2 className="text-lg mb-2 font-semibold uppercase text-gray-500">{title}</h2>}
|
||||
{children ? children : null}
|
||||
{children ?? null}
|
||||
</section>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -258,6 +258,25 @@ const docs: { [key: string]: ConfigDoc } = {
|
||||
</>
|
||||
),
|
||||
},
|
||||
'train.do_differential_guidance': {
|
||||
title: 'Differential Guidance',
|
||||
description: (
|
||||
<>
|
||||
Differential Guidance will amplify the difference of the model prediction and the target during training to make
|
||||
a new target. Differential Guidance Scale will be the multiplier for the difference. This is still experimental,
|
||||
but in my tests, it makes the model train faster, and learns details better in every scenario I have tried with
|
||||
it.
|
||||
<br />
|
||||
<br />
|
||||
The idea is that normal training inches closer to the target but never actually gets there, because it is
|
||||
limited by the learning rate. With differential guidance, we amplify the difference for a new target beyond the
|
||||
actual target, this would make the model learn to hit or overshoot the target instead of falling short.
|
||||
<br />
|
||||
<br />
|
||||
<img src="/imgs/diff_guidance.png" alt="Differential Guidance Diagram" className="max-w-full mx-auto" />
|
||||
</>
|
||||
),
|
||||
},
|
||||
};
|
||||
|
||||
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
|
||||
|
||||
@@ -139,6 +139,8 @@ export interface TrainConfig {
|
||||
blank_prompt_preservation_multiplier?: number;
|
||||
switch_boundary_every: number;
|
||||
loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped';
|
||||
do_differential_guidance?: boolean;
|
||||
differential_guidance_scale?: number;
|
||||
}
|
||||
|
||||
export interface QuantizeKwargsConfig {
|
||||
|
||||
@@ -1 +1 @@
|
||||
VERSION = "0.7.3"
|
||||
VERSION = "0.7.4"
|
||||
Reference in New Issue
Block a user