mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Change auto_memory to be layer_offloading and allow you to set the amount to unload
This commit is contained in:
@@ -125,8 +125,12 @@ class QwenImageModel(BaseModel):
|
||||
quantize_model(self, transformer)
|
||||
flush()
|
||||
|
||||
if self.model_config.auto_memory:
|
||||
MemoryManager.attach(transformer, self.device_torch)
|
||||
if self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0:
|
||||
MemoryManager.attach(
|
||||
transformer,
|
||||
self.device_torch,
|
||||
offload_percent=self.model_config.layer_offloading_transformer_percent
|
||||
)
|
||||
|
||||
if self.model_config.low_vram:
|
||||
self.print_and_status_update("Moving transformer to CPU")
|
||||
@@ -147,8 +151,12 @@ class QwenImageModel(BaseModel):
|
||||
if not self._qwen_image_keep_visual:
|
||||
text_encoder.model.visual = None
|
||||
|
||||
if self.model_config.auto_memory:
|
||||
MemoryManager.attach(text_encoder, self.device_torch)
|
||||
if self.model_config.layer_offloading and self.model_config.layer_offloading_text_encoder_percent > 0:
|
||||
MemoryManager.attach(
|
||||
text_encoder,
|
||||
self.device_torch,
|
||||
offload_percent=self.model_config.layer_offloading_text_encoder_percent
|
||||
)
|
||||
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
flush()
|
||||
|
||||
@@ -1759,7 +1759,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
)
|
||||
|
||||
# we cannot merge in if quantized
|
||||
if self.model_config.quantize or self.model_config.auto_memory:
|
||||
if self.model_config.quantize or self.model_config.layer_offloading:
|
||||
# todo find a way around this
|
||||
self.network.can_merge_in = False
|
||||
|
||||
|
||||
@@ -626,13 +626,19 @@ class ModelConfig:
|
||||
|
||||
# auto memory management, only for some models
|
||||
self.auto_memory = kwargs.get("auto_memory", False)
|
||||
if self.auto_memory and self.qtype == "qfloat8":
|
||||
print(f"Auto memory is not compatible with qfloat8, switching to float8 for model")
|
||||
# auto memory is deprecated, use layer offloading instead
|
||||
if self.auto_memory:
|
||||
print("auto_memory is deprecated, use layer_offloading instead")
|
||||
self.layer_offloading = kwargs.get("layer_offloading", self.auto_memory )
|
||||
if self.layer_offloading and self.qtype == "qfloat8":
|
||||
self.qtype = "float8"
|
||||
if self.auto_memory and not self.qtype_te == "qfloat8":
|
||||
print(f"Auto memory is not compatible with qfloat8, switching to float8 for te")
|
||||
if self.layer_offloading and not self.qtype_te == "qfloat8":
|
||||
self.qtype_te = "float8"
|
||||
|
||||
# 0 is off and 1.0 is 100% of the layers
|
||||
self.layer_offloading_transformer_percent = kwargs.get("layer_offloading_transformer_percent", 1.0)
|
||||
self.layer_offloading_text_encoder_percent = kwargs.get("layer_offloading_text_encoder_percent", 1.0)
|
||||
|
||||
# can be used to load the extras like text encoder or vae from here
|
||||
# only setup for some models but will prevent having to download the te for
|
||||
# 20 different model variants
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from .manager_modules import LinearLayerMemoryManager, ConvLayerMemoryManager
|
||||
import random
|
||||
|
||||
LINEAR_MODULES = [
|
||||
"Linear",
|
||||
@@ -60,7 +61,9 @@ class MemoryManager:
|
||||
return self.module
|
||||
|
||||
@classmethod
|
||||
def attach(cls, module: torch.nn.Module, device: torch.device):
|
||||
def attach(
|
||||
cls, module: torch.nn.Module, device: torch.device, offload_percent: float = 1.0
|
||||
):
|
||||
if hasattr(module, "_memory_manager"):
|
||||
# already attached
|
||||
return
|
||||
@@ -71,17 +74,44 @@ class MemoryManager:
|
||||
module._mm_to = module.to
|
||||
module.to = module._memory_manager.memory_managed_to
|
||||
|
||||
modules_processed = []
|
||||
# attach to all modules
|
||||
for name, sub_module in module.named_modules():
|
||||
for child_name, child_module in sub_module.named_modules():
|
||||
if child_module.__class__.__name__ in LINEAR_MODULES:
|
||||
# linear
|
||||
LinearLayerMemoryManager.attach(
|
||||
child_module, module._memory_manager
|
||||
)
|
||||
elif child_module.__class__.__name__ in CONV_MODULES:
|
||||
# conv
|
||||
ConvLayerMemoryManager.attach(child_module, module._memory_manager)
|
||||
if (
|
||||
child_module.__class__.__name__ in LINEAR_MODULES
|
||||
and child_module not in modules_processed
|
||||
):
|
||||
skip = False
|
||||
if offload_percent < 1.0:
|
||||
# randomly skip some modules
|
||||
if random.random() > offload_percent:
|
||||
skip = True
|
||||
if skip:
|
||||
module._memory_manager.unmanaged_modules.append(child_module)
|
||||
else:
|
||||
# linear
|
||||
LinearLayerMemoryManager.attach(
|
||||
child_module, module._memory_manager
|
||||
)
|
||||
modules_processed.append(child_module)
|
||||
elif (
|
||||
child_module.__class__.__name__ in CONV_MODULES
|
||||
and child_module not in modules_processed
|
||||
):
|
||||
skip = False
|
||||
if offload_percent < 1.0:
|
||||
# randomly skip some modules
|
||||
if random.random() > offload_percent:
|
||||
skip = True
|
||||
if skip:
|
||||
module._memory_manager.unmanaged_modules.append(child_module)
|
||||
else:
|
||||
# conv
|
||||
ConvLayerMemoryManager.attach(
|
||||
child_module, module._memory_manager
|
||||
)
|
||||
modules_processed.append(child_module)
|
||||
elif child_module.__class__.__name__ in UNMANAGED_MODULES or any(
|
||||
inc in child_module.__class__.__name__
|
||||
for inc in UNMANAGED_MODULES_INCLUDES
|
||||
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
import { defaultDatasetConfig } from './jobConfig';
|
||||
import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
|
||||
import { objectCopy } from '@/utils/basic';
|
||||
import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
|
||||
import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput, SliderInput } from '@/components/formInputs';
|
||||
import Card from '@/components/Card';
|
||||
import { X } from 'lucide-react';
|
||||
import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal';
|
||||
@@ -214,17 +214,47 @@ export default function SimpleJob({
|
||||
/>
|
||||
</FormGroup>
|
||||
)}
|
||||
{modelArch?.additionalSections?.includes('model.auto_memory') && (
|
||||
<Checkbox
|
||||
label={
|
||||
<>
|
||||
Auto Memory <IoFlaskSharp className="inline text-yellow-500" name="Experimental" />{' '}
|
||||
</>
|
||||
}
|
||||
checked={jobConfig.config.process[0].model.auto_memory || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.auto_memory')}
|
||||
docKey="model.auto_memory"
|
||||
/>
|
||||
{modelArch?.additionalSections?.includes('model.layer_offloading') && (
|
||||
<>
|
||||
<Checkbox
|
||||
label={
|
||||
<>
|
||||
Layer Offloading <IoFlaskSharp className="inline text-yellow-500" name="Experimental" />{' '}
|
||||
</>
|
||||
}
|
||||
checked={jobConfig.config.process[0].model.layer_offloading || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.layer_offloading')}
|
||||
docKey="model.layer_offloading"
|
||||
/>
|
||||
{jobConfig.config.process[0].model.layer_offloading && (
|
||||
<div className="pt-2">
|
||||
<SliderInput
|
||||
label="Transformer Offload %"
|
||||
value={Math.round(
|
||||
(jobConfig.config.process[0].model.layer_offloading_transformer_percent ?? 1) * 100,
|
||||
)}
|
||||
onChange={value =>
|
||||
setJobConfig(value * 0.01, 'config.process[0].model.layer_offloading_transformer_percent')
|
||||
}
|
||||
min={0}
|
||||
max={100}
|
||||
step={1}
|
||||
/>
|
||||
<SliderInput
|
||||
label="Text Encoder Offload %"
|
||||
value={Math.round(
|
||||
(jobConfig.config.process[0].model.layer_offloading_text_encoder_percent ?? 1) * 100,
|
||||
)}
|
||||
onChange={value =>
|
||||
setJobConfig(value * 0.01, 'config.process[0].model.layer_offloading_text_encoder_percent')
|
||||
}
|
||||
min={0}
|
||||
max={100}
|
||||
step={1}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</Card>
|
||||
{disableSections.includes('model.quantize') ? null : (
|
||||
|
||||
@@ -25,7 +25,7 @@ export const defaultSliderConfig: SliderConfig = {
|
||||
positive_prompt: 'person who is happy',
|
||||
negative_prompt: 'person who is sad',
|
||||
target_class: 'person',
|
||||
anchor_class: "",
|
||||
anchor_class: '',
|
||||
};
|
||||
|
||||
export const defaultJobConfig: JobConfig = {
|
||||
@@ -181,5 +181,11 @@ export const migrateJobConfig = (jobConfig: JobConfig): JobConfig => {
|
||||
if (jobConfig?.config?.process && jobConfig.config.process[0]?.type === 'ui_trainer') {
|
||||
jobConfig.config.process[0].type = 'diffusion_trainer';
|
||||
}
|
||||
|
||||
if ('auto_memory' in jobConfig.config.process[0].model) {
|
||||
jobConfig.config.process[0].model.layer_offloading = (jobConfig.config.process[0].model.auto_memory ||
|
||||
false) as boolean;
|
||||
delete jobConfig.config.process[0].model.auto_memory;
|
||||
}
|
||||
return jobConfig;
|
||||
};
|
||||
|
||||
@@ -20,7 +20,7 @@ type AdditionalSections =
|
||||
| 'sample.multi_ctrl_imgs'
|
||||
| 'datasets.num_frames'
|
||||
| 'model.multistage'
|
||||
| 'model.auto_memory'
|
||||
| 'model.layer_offloading'
|
||||
| 'model.low_vram';
|
||||
type ModelGroup = 'image' | 'instruction' | 'video';
|
||||
|
||||
@@ -313,7 +313,7 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['model.low_vram', 'model.auto_memory'],
|
||||
additionalSections: ['model.low_vram', 'model.layer_offloading'],
|
||||
accuracyRecoveryAdapters: {
|
||||
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors',
|
||||
},
|
||||
@@ -334,7 +334,7 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram', 'model.auto_memory'],
|
||||
additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram', 'model.layer_offloading'],
|
||||
accuracyRecoveryAdapters: {
|
||||
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors',
|
||||
},
|
||||
@@ -360,7 +360,7 @@ export const modelArchs: ModelArch[] = [
|
||||
'datasets.multi_control_paths',
|
||||
'sample.multi_ctrl_imgs',
|
||||
'model.low_vram',
|
||||
'model.auto_memory',
|
||||
'model.layer_offloading',
|
||||
],
|
||||
accuracyRecoveryAdapters: {
|
||||
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors',
|
||||
|
||||
@@ -21,17 +21,21 @@ export const handleModelArchChange = (
|
||||
setJobConfig(false, 'config.process[0].model.low_vram');
|
||||
}
|
||||
|
||||
// handle auto memory setting
|
||||
if (!newArch?.additionalSections?.includes('model.auto_memory')) {
|
||||
if ('auto_memory' in jobConfig.config.process[0].model) {
|
||||
// handle layer offloading setting
|
||||
if (!newArch?.additionalSections?.includes('model.layer_offloading')) {
|
||||
if ('layer_offloading' in jobConfig.config.process[0].model) {
|
||||
const newModel = objectCopy(jobConfig.config.process[0].model);
|
||||
delete newModel.auto_memory;
|
||||
delete newModel.layer_offloading;
|
||||
delete newModel.layer_offloading_text_encoder_percent;
|
||||
delete newModel.layer_offloading_transformer_percent;
|
||||
setJobConfig(newModel, 'config.process[0].model');
|
||||
}
|
||||
} else {
|
||||
// set to false if not set
|
||||
if (!('auto_memory' in jobConfig.config.process[0].model)) {
|
||||
setJobConfig(false, 'config.process[0].model.auto_memory');
|
||||
if (!('layer_offloading' in jobConfig.config.process[0].model)) {
|
||||
setJobConfig(false, 'config.process[0].model.layer_offloading');
|
||||
setJobConfig(1.0, 'config.process[0].model.layer_offloading_text_encoder_percent');
|
||||
setJobConfig(1.0, 'config.process[0].model.layer_offloading_transformer_percent');
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -296,3 +296,147 @@ export const FormGroup: React.FC<FormGroupProps> = props => {
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
/**
 * Props accepted by `SliderInput`, in addition to the shared `InputProps`
 * (label, className, doc, docKey as used by the component below).
 */
export interface SliderInputProps extends InputProps {
  /** Current value; rendered as the fill width / thumb position between `min` and `max`. */
  value: number;
  /** Called with the clamped, step-snapped value whenever the user presses or drags on the track. */
  onChange: (value: number) => void;
  /** Lower bound of the range (also shown as the left-hand caption). */
  min: number;
  /** Upper bound of the range (also shown as the right-hand caption). */
  max: number;
  /** Snap increment measured from `min`. Defaults to 1. */
  step?: number;
  /** When true the control is dimmed and ignores pointer input. */
  disabled?: boolean;
  /** Whether to render the numeric value box beside the track. Defaults to true. */
  showValue?: boolean;
}

/**
 * Pointer-driven horizontal slider.
 *
 * Pressing anywhere on the track (or on the thumb, whose pointerdown bubbles to
 * the track) jumps to that position and starts a drag; the drag is followed via
 * window-level pointer listeners until the pointer is released or cancelled.
 */
export const SliderInput: React.FC<SliderInputProps> = props => {
  const { label, value, onChange, min, max, step = 1, disabled, className, docKey = null, showValue = true } = props;
  let { doc } = props;
  if (!doc && docKey) {
    doc = getDoc(docKey);
  }

  const trackRef = React.useRef<HTMLDivElement | null>(null);
  // Holds the teardown for the drag currently in progress (if any) so it can be
  // run on unmount or before a new drag starts.
  const detachDragRef = React.useRef<(() => void) | null>(null);
  const [dragging, setDragging] = React.useState(false);

  // Remove any window listeners still attached if the component unmounts mid-drag.
  React.useEffect(() => {
    return () => {
      detachDragRef.current?.();
    };
  }, []);

  const clamp = (v: number) => (v < min ? min : v > max ? max : v);
  const snapToStep = (v: number) => {
    if (!Number.isFinite(v)) return min;
    // A zero, negative or NaN step would divide by zero below and hand NaN to
    // onChange; fall back to plain clamping instead.
    if (!(step > 0)) return clamp(v);
    const steps = Math.round((v - min) / step);
    const snapped = min + steps * step;
    // toFixed(6) trims floating-point noise from fractional steps.
    return clamp(Number(snapped.toFixed(6)));
  };

  // Position of the fill/thumb as a percentage of the track width, clamped to [0, 100].
  const percent = React.useMemo(() => {
    if (max === min) return 0;
    const p = ((value - min) / (max - min)) * 100;
    return p < 0 ? 0 : p > 100 ? 100 : p;
  }, [value, min, max]);

  // Convert a viewport X coordinate into a value and report it through onChange.
  const calcFromClientX = React.useCallback(
    (clientX: number) => {
      const el = trackRef.current;
      if (!el || !Number.isFinite(clientX)) return;
      const rect = el.getBoundingClientRect();
      const width = rect.right - rect.left;
      if (!(width > 0)) return;

      // Clamp ratio to [0, 1] so it can never flip ends.
      const ratioRaw = (clientX - rect.left) / width;
      const ratio = ratioRaw <= 0 ? 0 : ratioRaw >= 1 ? 1 : ratioRaw;

      const raw = min + ratio * (max - min);
      onChange(snapToStep(raw));
    },
    [min, max, step, onChange],
  );

  // Mouse/touch pointer drag
  const onPointerDown = (e: React.PointerEvent) => {
    if (disabled) return;
    e.preventDefault();

    // Drop listeners from a previous drag that never saw pointerup/pointercancel.
    detachDragRef.current?.();

    // Read these now: React resets `currentTarget` on the synthetic event once
    // this handler returns, so it cannot be read later from the window listeners.
    const captureTarget = e.currentTarget as HTMLElement;
    const pointerId = e.pointerId;

    // Capture the pointer so moves outside the element are still tracked correctly
    try {
      captureTarget.setPointerCapture?.(pointerId);
    } catch {}

    setDragging(true);
    calcFromClientX(e.clientX);

    const handleMove = (ev: PointerEvent) => {
      ev.preventDefault();
      calcFromClientX(ev.clientX);
    };
    const detach = () => {
      // release capture if we got it
      try {
        captureTarget.releasePointerCapture?.(pointerId);
      } catch {}
      window.removeEventListener('pointermove', handleMove);
      window.removeEventListener('pointerup', handleEnd);
      window.removeEventListener('pointercancel', handleEnd);
      detachDragRef.current = null;
    };
    const handleEnd = () => {
      detach();
      setDragging(false);
    };

    window.addEventListener('pointermove', handleMove);
    window.addEventListener('pointerup', handleEnd);
    // A cancelled gesture (e.g. the browser taking over a touch) never fires
    // pointerup, so end the drag on pointercancel as well.
    window.addEventListener('pointercancel', handleEnd);
    detachDragRef.current = detach;
  };

  return (
    <div className={classNames(className, disabled ? 'opacity-30 cursor-not-allowed' : '')}>
      {label && (
        <label className={labelClasses}>
          {label}{' '}
          {doc && (
            <div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
              <CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
            </div>
          )}
        </label>
      )}

      <div className="flex items-center gap-3">
        <div className="flex-1">
          <div
            ref={trackRef}
            onPointerDown={onPointerDown}
            className={classNames(
              'relative w-full h-6 select-none outline-none',
              disabled ? 'pointer-events-none' : 'cursor-pointer',
            )}
          >
            {/* Thicker track */}
            <div className="pointer-events-none absolute left-0 right-0 top-1/2 -translate-y-1/2 h-3 rounded-sm bg-gray-800 border border-gray-700" />

            {/* Fill */}
            <div
              className="pointer-events-none absolute left-0 top-1/2 -translate-y-1/2 h-3 rounded-sm bg-blue-600"
              style={{ width: `${percent}%` }}
            />

            {/* Thumb: no handler of its own. Its pointerdown bubbles to the track,
                so attaching one here as well would start every drag twice. */}
            <div
              className={classNames(
                'absolute top-1/2 -translate-y-1/2 -ml-2',
                'h-4 w-4 rounded-full bg-white shadow border border-gray-300 cursor-pointer',
                'after:content-[""] after:absolute after:inset-[-6px] after:rounded-full after:bg-transparent', // expands hit area
                dragging ? 'ring-2 ring-blue-600' : '',
              )}
              style={{ left: `${percent}%` }}
            />
          </div>

          <div className="flex justify-between text-xs text-gray-500 mt-0.5 select-none">
            <span>{min}</span>
            <span>{max}</span>
          </div>
        </div>

        {showValue && (
          <div className="min-w-[3.5rem] text-right text-sm px-3 py-1 bg-gray-800 border border-gray-700 rounded-sm">
            {Number.isFinite(value) ? value : ''}
          </div>
        )}
      </div>
    </div>
  );
};
|
||||
|
||||
@@ -185,10 +185,10 @@ const docs: { [key: string]: ConfigDoc } = {
|
||||
</>
|
||||
),
|
||||
},
|
||||
'model.auto_memory': {
|
||||
'model.layer_offloading': {
|
||||
title: (
|
||||
<>
|
||||
Auto Memory{' '}
|
||||
Layer Offloading{' '}
|
||||
<span className="text-yellow-500">
|
||||
( <IoFlaskSharp className="inline text-yellow-500" name="Experimental" /> Experimental)
|
||||
</span>
|
||||
@@ -204,10 +204,14 @@ const docs: { [key: string]: ConfigDoc } = {
|
||||
one update to the next. It will also only work with certain models.
|
||||
<br />
|
||||
<br />
|
||||
Auto Memory uses the CPU RAM instead of the GPU ram to hold most of the model weights. This allows training a
|
||||
Layer Offloading uses the CPU RAM instead of the GPU ram to hold most of the model weights. This allows training a
|
||||
much larger model on a smaller GPU, assuming you have enough CPU RAM. This is slower than training on pure GPU
|
||||
RAM, but CPU RAM is cheaper and upgradeable. You will still need GPU RAM to hold the optimizer states and LoRA weights,
|
||||
so a larger card is usually still needed.
|
||||
<br />
|
||||
<br />
|
||||
You can also select the percentage of the layers to offload. It is generally best to offload as few as possible (close to 0%)
|
||||
for best performance, but you can offload more if you need the memory.
|
||||
</>
|
||||
),
|
||||
},
|
||||
|
||||
@@ -153,7 +153,9 @@ export interface ModelConfig {
|
||||
arch: string;
|
||||
low_vram: boolean;
|
||||
model_kwargs: { [key: string]: any };
|
||||
auto_memory?: boolean;
|
||||
layer_offloading?: boolean;
|
||||
layer_offloading_transformer_percent?: number;
|
||||
layer_offloading_text_encoder_percent?: number;
|
||||
}
|
||||
|
||||
export interface SampleItem {
|
||||
|
||||
Reference in New Issue
Block a user