Change auto_memory to layer_offloading and allow setting how much to offload

Jaret Burkett
2025-10-10 13:12:32 -06:00
parent 2c2fbf16ea
commit 1bc6dee127
11 changed files with 279 additions and 45 deletions

View File

@@ -125,8 +125,12 @@ class QwenImageModel(BaseModel):
             quantize_model(self, transformer)
             flush()

-        if self.model_config.auto_memory:
-            MemoryManager.attach(transformer, self.device_torch)
+        if self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0:
+            MemoryManager.attach(
+                transformer,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_transformer_percent
+            )

         if self.model_config.low_vram:
             self.print_and_status_update("Moving transformer to CPU")
@@ -147,8 +151,12 @@ class QwenImageModel(BaseModel):
         if not self._qwen_image_keep_visual:
             text_encoder.model.visual = None

-        if self.model_config.auto_memory:
-            MemoryManager.attach(text_encoder, self.device_torch)
+        if self.model_config.layer_offloading and self.model_config.layer_offloading_text_encoder_percent > 0:
+            MemoryManager.attach(
+                text_encoder,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_text_encoder_percent
+            )

         text_encoder.to(self.device_torch, dtype=dtype)
         flush()

View File

@@ -1759,7 +1759,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
)
# we cannot merge in if quantized
if self.model_config.quantize or self.model_config.auto_memory:
if self.model_config.quantize or self.model_config.layer_offloading:
# todo find a way around this
self.network.can_merge_in = False

View File

@@ -626,12 +626,18 @@ class ModelConfig:
         # auto memory management, only for some models
         self.auto_memory = kwargs.get("auto_memory", False)
-        if self.auto_memory and self.qtype == "qfloat8":
-            print(f"Auto memory is not compatible with qfloat8, switching to float8 for model")
+        # auto memory is deprecated, use layer offloading instead
+        if self.auto_memory:
+            print("auto_memory is deprecated, use layer_offloading instead")
+        self.layer_offloading = kwargs.get("layer_offloading", self.auto_memory )
+        if self.layer_offloading and self.qtype == "qfloat8":
             self.qtype = "float8"
-        if self.auto_memory and not self.qtype_te == "qfloat8":
-            print(f"Auto memory is not compatible with qfloat8, switching to float8 for te")
+        if self.layer_offloading and not self.qtype_te == "qfloat8":
             self.qtype_te = "float8"
+        # 0 is off and 1.0 is 100% of the layers
+        self.layer_offloading_transformer_percent = kwargs.get("layer_offloading_transformer_percent", 1.0)
+        self.layer_offloading_text_encoder_percent = kwargs.get("layer_offloading_text_encoder_percent", 1.0)

         # can be used to load the extras like text encoder or vae from here
         # only setup for some models but will prevent having to download the te for
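
Note that the old `auto_memory` flag still works, but only as a fallback default for `layer_offloading`. Below is a small standalone sketch (not part of this commit) of how the kwargs resolve, mirroring the `kwargs.get` calls in the hunk above on a plain dict:

```python
# Standalone illustration of the fallback logic above; not repo code.
def resolve_offloading(kwargs: dict) -> dict:
    auto_memory = kwargs.get("auto_memory", False)  # deprecated flag
    return {
        "layer_offloading": kwargs.get("layer_offloading", auto_memory),
        "transformer_percent": kwargs.get("layer_offloading_transformer_percent", 1.0),
        "text_encoder_percent": kwargs.get("layer_offloading_text_encoder_percent", 1.0),
    }

print(resolve_offloading({"auto_memory": True}))
# -> layer_offloading True, both percents default to 1.0 (offload everything)
print(resolve_offloading({"layer_offloading": True, "layer_offloading_transformer_percent": 0.5}))
# -> offload roughly half of the transformer layers, all text encoder layers
```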

View File

@@ -1,5 +1,6 @@
 import torch
 from .manager_modules import LinearLayerMemoryManager, ConvLayerMemoryManager
+import random

 LINEAR_MODULES = [
     "Linear",
@@ -60,7 +61,9 @@ class MemoryManager:
         return self.module

     @classmethod
-    def attach(cls, module: torch.nn.Module, device: torch.device):
+    def attach(
+        cls, module: torch.nn.Module, device: torch.device, offload_percent: float = 1.0
+    ):
         if hasattr(module, "_memory_manager"):
             # already attached
             return
@@ -71,17 +74,44 @@ class MemoryManager:
         module._mm_to = module.to
         module.to = module._memory_manager.memory_managed_to

+        modules_processed = []
         # attach to all modules
         for name, sub_module in module.named_modules():
             for child_name, child_module in sub_module.named_modules():
-                if child_module.__class__.__name__ in LINEAR_MODULES:
-                    # linear
-                    LinearLayerMemoryManager.attach(
-                        child_module, module._memory_manager
-                    )
-                elif child_module.__class__.__name__ in CONV_MODULES:
-                    # conv
-                    ConvLayerMemoryManager.attach(child_module, module._memory_manager)
+                if (
+                    child_module.__class__.__name__ in LINEAR_MODULES
+                    and child_module not in modules_processed
+                ):
+                    skip = False
+                    if offload_percent < 1.0:
+                        # randomly skip some modules
+                        if random.random() > offload_percent:
+                            skip = True
+                    if skip:
+                        module._memory_manager.unmanaged_modules.append(child_module)
+                    else:
+                        # linear
+                        LinearLayerMemoryManager.attach(
+                            child_module, module._memory_manager
+                        )
+                        modules_processed.append(child_module)
+                elif (
+                    child_module.__class__.__name__ in CONV_MODULES
+                    and child_module not in modules_processed
+                ):
+                    skip = False
+                    if offload_percent < 1.0:
+                        # randomly skip some modules
+                        if random.random() > offload_percent:
+                            skip = True
+                    if skip:
+                        module._memory_manager.unmanaged_modules.append(child_module)
+                    else:
+                        # conv
+                        ConvLayerMemoryManager.attach(
+                            child_module, module._memory_manager
+                        )
+                        modules_processed.append(child_module)
                 elif child_module.__class__.__name__ in UNMANAGED_MODULES or any(
                     inc in child_module.__class__.__name__
                     for inc in UNMANAGED_MODULES_INCLUDES
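
In the loop above, each eligible Linear/Conv layer is offloaded independently with probability `offload_percent`; skipped layers go to `unmanaged_modules` and stay resident. A tiny standalone sketch of that selection rule and its expected outcome (not repo code):

```python
import random

# Each eligible layer is offloaded with probability offload_percent,
# otherwise it is left resident (the "skip" branch above).
def is_offloaded(offload_percent: float) -> bool:
    if offload_percent < 1.0 and random.random() > offload_percent:
        return False  # skipped -> weights stay on the GPU
    return True       # managed -> weights held in CPU RAM

random.seed(0)
n_layers = 10_000
offloaded = sum(is_offloaded(0.3) for _ in range(n_layers))
print(f"offloaded {offloaded / n_layers:.1%} of layers")  # ~30%, varies per run
```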

View File

@@ -11,7 +11,7 @@ import {
 import { defaultDatasetConfig } from './jobConfig';
 import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
 import { objectCopy } from '@/utils/basic';
-import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
+import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput, SliderInput } from '@/components/formInputs';
 import Card from '@/components/Card';
 import { X } from 'lucide-react';
 import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal';
@@ -214,17 +214,47 @@ export default function SimpleJob({
               />
             </FormGroup>
           )}
-          {modelArch?.additionalSections?.includes('model.auto_memory') && (
-            <Checkbox
-              label={
-                <>
-                  Auto Memory <IoFlaskSharp className="inline text-yellow-500" name="Experimental" />{' '}
-                </>
-              }
-              checked={jobConfig.config.process[0].model.auto_memory || false}
-              onChange={value => setJobConfig(value, 'config.process[0].model.auto_memory')}
-              docKey="model.auto_memory"
-            />
+          {modelArch?.additionalSections?.includes('model.layer_offloading') && (
+            <>
+              <Checkbox
+                label={
+                  <>
+                    Layer Offloading <IoFlaskSharp className="inline text-yellow-500" name="Experimental" />{' '}
+                  </>
+                }
+                checked={jobConfig.config.process[0].model.layer_offloading || false}
+                onChange={value => setJobConfig(value, 'config.process[0].model.layer_offloading')}
+                docKey="model.layer_offloading"
+              />
+              {jobConfig.config.process[0].model.layer_offloading && (
+                <div className="pt-2">
+                  <SliderInput
+                    label="Transformer Offload %"
+                    value={Math.round(
+                      (jobConfig.config.process[0].model.layer_offloading_transformer_percent ?? 1) * 100,
+                    )}
+                    onChange={value =>
+                      setJobConfig(value * 0.01, 'config.process[0].model.layer_offloading_transformer_percent')
+                    }
+                    min={0}
+                    max={100}
+                    step={1}
+                  />
+                  <SliderInput
+                    label="Text Encoder Offload %"
+                    value={Math.round(
+                      (jobConfig.config.process[0].model.layer_offloading_text_encoder_percent ?? 1) * 100,
+                    )}
+                    onChange={value =>
+                      setJobConfig(value * 0.01, 'config.process[0].model.layer_offloading_text_encoder_percent')
+                    }
+                    min={0}
+                    max={100}
+                    step={1}
+                  />
+                </div>
+              )}
+            </>
           )}
         </Card>
         {disableSections.includes('model.quantize') ? null : (

View File

@@ -25,7 +25,7 @@ export const defaultSliderConfig: SliderConfig = {
   positive_prompt: 'person who is happy',
   negative_prompt: 'person who is sad',
   target_class: 'person',
-  anchor_class: "",
+  anchor_class: '',
 };

 export const defaultJobConfig: JobConfig = {
@@ -181,5 +181,11 @@ export const migrateJobConfig = (jobConfig: JobConfig): JobConfig => {
   if (jobConfig?.config?.process && jobConfig.config.process[0]?.type === 'ui_trainer') {
     jobConfig.config.process[0].type = 'diffusion_trainer';
   }
+
+  if ('auto_memory' in jobConfig.config.process[0].model) {
+    jobConfig.config.process[0].model.layer_offloading = (jobConfig.config.process[0].model.auto_memory ||
+      false) as boolean;
+    delete jobConfig.config.process[0].model.auto_memory;
+  }
   return jobConfig;
 };

View File

@@ -20,7 +20,7 @@ type AdditionalSections =
   | 'sample.multi_ctrl_imgs'
   | 'datasets.num_frames'
   | 'model.multistage'
-  | 'model.auto_memory'
+  | 'model.layer_offloading'
   | 'model.low_vram';

 type ModelGroup = 'image' | 'instruction' | 'video';
@@ -313,7 +313,7 @@ export const modelArchs: ModelArch[] = [
       'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
     },
     disableSections: ['network.conv'],
-    additionalSections: ['model.low_vram', 'model.auto_memory'],
+    additionalSections: ['model.low_vram', 'model.layer_offloading'],
     accuracyRecoveryAdapters: {
       '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors',
     },
@@ -334,7 +334,7 @@ export const modelArchs: ModelArch[] = [
      'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
     },
     disableSections: ['network.conv'],
-    additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram', 'model.auto_memory'],
+    additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram', 'model.layer_offloading'],
     accuracyRecoveryAdapters: {
       '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors',
     },
@@ -360,7 +360,7 @@ export const modelArchs: ModelArch[] = [
       'datasets.multi_control_paths',
       'sample.multi_ctrl_imgs',
       'model.low_vram',
-      'model.auto_memory',
+      'model.layer_offloading',
     ],
     accuracyRecoveryAdapters: {
       '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors',

View File

@@ -21,17 +21,21 @@ export const handleModelArchChange = (
     setJobConfig(false, 'config.process[0].model.low_vram');
   }

-  // handle auto memory setting
-  if (!newArch?.additionalSections?.includes('model.auto_memory')) {
-    if ('auto_memory' in jobConfig.config.process[0].model) {
+  // handle layer offloading setting
+  if (!newArch?.additionalSections?.includes('model.layer_offloading')) {
+    if ('layer_offloading' in jobConfig.config.process[0].model) {
       const newModel = objectCopy(jobConfig.config.process[0].model);
-      delete newModel.auto_memory;
+      delete newModel.layer_offloading;
+      delete newModel.layer_offloading_text_encoder_percent;
+      delete newModel.layer_offloading_transformer_percent;
       setJobConfig(newModel, 'config.process[0].model');
     }
   } else {
     // set to false if not set
-    if (!('auto_memory' in jobConfig.config.process[0].model)) {
-      setJobConfig(false, 'config.process[0].model.auto_memory');
+    if (!('layer_offloading' in jobConfig.config.process[0].model)) {
+      setJobConfig(false, 'config.process[0].model.layer_offloading');
+      setJobConfig(1.0, 'config.process[0].model.layer_offloading_text_encoder_percent');
+      setJobConfig(1.0, 'config.process[0].model.layer_offloading_transformer_percent');
     }
   }


View File

@@ -296,3 +296,147 @@ export const FormGroup: React.FC<FormGroupProps> = props => {
     </div>
   );
 };
+
+export interface SliderInputProps extends InputProps {
+  value: number;
+  onChange: (value: number) => void;
+  min: number;
+  max: number;
+  step?: number;
+  disabled?: boolean;
+  showValue?: boolean;
+}
+
+export const SliderInput: React.FC<SliderInputProps> = props => {
+  const { label, value, onChange, min, max, step = 1, disabled, className, docKey = null, showValue = true } = props;
+  let { doc } = props;
+  if (!doc && docKey) {
+    doc = getDoc(docKey);
+  }
+
+  const trackRef = React.useRef<HTMLDivElement | null>(null);
+  const [dragging, setDragging] = React.useState(false);
+
+  const clamp = (v: number) => (v < min ? min : v > max ? max : v);
+
+  const snapToStep = (v: number) => {
+    if (!Number.isFinite(v)) return min;
+    const steps = Math.round((v - min) / step);
+    const snapped = min + steps * step;
+    return clamp(Number(snapped.toFixed(6)));
+  };
+
+  const percent = React.useMemo(() => {
+    if (max === min) return 0;
+    const p = ((value - min) / (max - min)) * 100;
+    return p < 0 ? 0 : p > 100 ? 100 : p;
+  }, [value, min, max]);
+
+  const calcFromClientX = React.useCallback(
+    (clientX: number) => {
+      const el = trackRef.current;
+      if (!el || !Number.isFinite(clientX)) return;
+      const rect = el.getBoundingClientRect();
+      const width = rect.right - rect.left;
+      if (!(width > 0)) return;
+      // Clamp ratio to [0, 1] so it can never flip ends.
+      const ratioRaw = (clientX - rect.left) / width;
+      const ratio = ratioRaw <= 0 ? 0 : ratioRaw >= 1 ? 1 : ratioRaw;
+      const raw = min + ratio * (max - min);
+      onChange(snapToStep(raw));
+    },
+    [min, max, step, onChange],
+  );
+
+  // Mouse/touch pointer drag
+  const onPointerDown = (e: React.PointerEvent) => {
+    if (disabled) return;
+    e.preventDefault();
+    // Capture the pointer so moves outside the element are still tracked correctly
+    try {
+      (e.currentTarget as HTMLElement).setPointerCapture?.(e.pointerId);
+    } catch {}
+    setDragging(true);
+    calcFromClientX(e.clientX);
+
+    const handleMove = (ev: PointerEvent) => {
+      ev.preventDefault();
+      calcFromClientX(ev.clientX);
+    };
+    const handleUp = (ev: PointerEvent) => {
+      setDragging(false);
+      // release capture if we got it
+      try {
+        (e.currentTarget as HTMLElement).releasePointerCapture?.(e.pointerId);
+      } catch {}
+      window.removeEventListener('pointermove', handleMove);
+      window.removeEventListener('pointerup', handleUp);
+    };
+    window.addEventListener('pointermove', handleMove);
+    window.addEventListener('pointerup', handleUp);
+  };
+
+  return (
+    <div className={classNames(className, disabled ? 'opacity-30 cursor-not-allowed' : '')}>
+      {label && (
+        <label className={labelClasses}>
+          {label}{' '}
+          {doc && (
+            <div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
+              <CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
+            </div>
+          )}
+        </label>
+      )}
+      <div className="flex items-center gap-3">
+        <div className="flex-1">
+          <div
+            ref={trackRef}
+            onPointerDown={onPointerDown}
+            className={classNames(
+              'relative w-full h-6 select-none outline-none',
+              disabled ? 'pointer-events-none' : 'cursor-pointer',
+            )}
+          >
+            {/* Thicker track */}
+            <div className="pointer-events-none absolute left-0 right-0 top-1/2 -translate-y-1/2 h-3 rounded-sm bg-gray-800 border border-gray-700" />
+            {/* Fill */}
+            <div
+              className="pointer-events-none absolute left-0 top-1/2 -translate-y-1/2 h-3 rounded-sm bg-blue-600"
+              style={{ width: `${percent}%` }}
+            />
+            {/* Thumb */}
+            <div
+              onPointerDown={onPointerDown}
+              className={classNames(
+                'absolute top-1/2 -translate-y-1/2 -ml-2',
+                'h-4 w-4 rounded-full bg-white shadow border border-gray-300 cursor-pointer',
+                'after:content-[""] after:absolute after:inset-[-6px] after:rounded-full after:bg-transparent', // expands hit area
+                dragging ? 'ring-2 ring-blue-600' : '',
+              )}
+              style={{ left: `calc(${percent}% )` }}
+            />
+          </div>
+          <div className="flex justify-between text-xs text-gray-500 mt-0.5 select-none">
+            <span>{min}</span>
+            <span>{max}</span>
+          </div>
+        </div>
+        {showValue && (
+          <div className="min-w-[3.5rem] text-right text-sm px-3 py-1 bg-gray-800 border border-gray-700 rounded-sm">
+            {Number.isFinite(value) ? value : ''}
+          </div>
+        )}
+      </div>
+    </div>
+  );
+};

View File

@@ -185,10 +185,10 @@ const docs: { [key: string]: ConfigDoc } = {
       </>
     ),
   },
-  'model.auto_memory': {
+  'model.layer_offloading': {
     title: (
       <>
-        Auto Memory{' '}
+        Layer Offloading{' '}
        <span className="text-yellow-500">
          ( <IoFlaskSharp className="inline text-yellow-500" name="Experimental" /> Experimental)
        </span>
@@ -204,10 +204,14 @@ const docs: { [key: string]: ConfigDoc } = {
        one update to the next. It will also only work with certain models.
        <br />
        <br />
-        Auto Memory uses the CPU RAM instead of the GPU ram to hold most of the model weights. This allows training a
+        Layer Offloading uses the CPU RAM instead of the GPU ram to hold most of the model weights. This allows training a
        much larger model on a smaller GPU, assuming you have enough CPU RAM. This is slower than training on pure GPU
        RAM, but CPU RAM is cheaper and upgradeable. You will still need GPU RAM to hold the optimizer states and LoRA weights,
        so a larger card is usually still needed.
+        <br />
+        <br />
+        You can also select the percentage of the layers to offload. It is generally best to offload as few as possible (close to 0%)
+        for best performance, but you can offload more if you need the memory.
      </>
    ),
  },
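
As described in the doc text above, the percentage is roughly the fraction of offloadable layer weights that end up in CPU RAM instead of VRAM. A hypothetical back-of-the-envelope sketch (the 20 GB figure is an assumption, not from the repo):

```python
# Rough intuition only: expected VRAM saved by partial layer offloading.
offloadable_weights_gb = 20.0  # assumed size of the transformer's Linear/Conv weights
percent = 0.75                 # layer_offloading_transformer_percent

in_cpu_ram_gb = offloadable_weights_gb * percent
stays_on_gpu_gb = offloadable_weights_gb - in_cpu_ram_gb
print(f"~{in_cpu_ram_gb:.0f} GB moves to CPU RAM, ~{stays_on_gpu_gb:.0f} GB stays in VRAM")
# Optimizer states, activations, and non-offloadable modules still need VRAM on top of this.
```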

View File

@@ -153,7 +153,9 @@ export interface ModelConfig {
   arch: string;
   low_vram: boolean;
   model_kwargs: { [key: string]: any };
-  auto_memory?: boolean;
+  layer_offloading?: boolean;
+  layer_offloading_transformer_percent?: number;
+  layer_offloading_text_encoder_percent?: number;
 }

 export interface SampleItem {