Change auto_memory to layer_offloading and allow setting the amount to offload

Jaret Burkett
2025-10-10 13:12:32 -06:00
parent 2c2fbf16ea
commit 1bc6dee127
11 changed files with 279 additions and 45 deletions

View File

@@ -125,8 +125,12 @@ class QwenImageModel(BaseModel):
             quantize_model(self, transformer)
             flush()
 
-        if self.model_config.auto_memory:
-            MemoryManager.attach(transformer, self.device_torch)
+        if self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0:
+            MemoryManager.attach(
+                transformer,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_transformer_percent
+            )
 
         if self.model_config.low_vram:
             self.print_and_status_update("Moving transformer to CPU")
@@ -147,8 +151,12 @@ class QwenImageModel(BaseModel):
         if not self._qwen_image_keep_visual:
             text_encoder.model.visual = None
 
-        if self.model_config.auto_memory:
-            MemoryManager.attach(text_encoder, self.device_torch)
+        if self.model_config.layer_offloading and self.model_config.layer_offloading_text_encoder_percent > 0:
+            MemoryManager.attach(
+                text_encoder,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_text_encoder_percent
+            )
 
         text_encoder.to(self.device_torch, dtype=dtype)
         flush()
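
The attach call is now gated twice: layer offloading must be enabled, and the component's offload percent must be above zero. For reference, a hypothetical set of ModelConfig kwargs exercising the keys this commit adds (the key names come from the diff; the values are illustrative only):

    model_kwargs = {
        "layer_offloading": True,
        # 0.0 disables offloading for a component, 1.0 offloads every eligible layer
        "layer_offloading_transformer_percent": 0.5,
        "layer_offloading_text_encoder_percent": 1.0,
    }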

View File

@@ -1759,7 +1759,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             )
 
             # we cannot merge in if quantized
-            if self.model_config.quantize or self.model_config.auto_memory:
+            if self.model_config.quantize or self.model_config.layer_offloading:
                 # todo find a way around this
                 self.network.can_merge_in = False

View File

@@ -626,12 +626,18 @@ class ModelConfig:
         # auto memory management, only for some models
         self.auto_memory = kwargs.get("auto_memory", False)
-        if self.auto_memory and self.qtype == "qfloat8":
-            print(f"Auto memory is not compatible with qfloat8, switching to float8 for model")
+        # auto memory is deprecated, use layer offloading instead
+        if self.auto_memory:
+            print("auto_memory is deprecated, use layer_offloading instead")
+        self.layer_offloading = kwargs.get("layer_offloading", self.auto_memory)
+        if self.layer_offloading and self.qtype == "qfloat8":
             self.qtype = "float8"
-        if self.auto_memory and not self.qtype_te == "qfloat8":
-            print(f"Auto memory is not compatible with qfloat8, switching to float8 for te")
+        if self.layer_offloading and not self.qtype_te == "qfloat8":
             self.qtype_te = "float8"
+        # 0 is off and 1.0 is 100% of the layers
+        self.layer_offloading_transformer_percent = kwargs.get("layer_offloading_transformer_percent", 1.0)
+        self.layer_offloading_text_encoder_percent = kwargs.get("layer_offloading_text_encoder_percent", 1.0)
 
         # can be used to load the extras like text encoder or vae from here
         # only setup for some models but will prevent having to download the te for
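
Because layer_offloading falls back to the (now deprecated) auto_memory value, old configs keep working. A minimal sketch of that fallback using plain dicts rather than the real ModelConfig:

    legacy_kwargs = {"auto_memory": True}  # old-style config, no new keys set
    layer_offloading = legacy_kwargs.get(
        "layer_offloading", legacy_kwargs.get("auto_memory", False)
    )
    print(layer_offloading)  # True: the legacy flag still enables offloading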

View File

@@ -1,5 +1,6 @@
 import torch
 from .manager_modules import LinearLayerMemoryManager, ConvLayerMemoryManager
+import random
 
 LINEAR_MODULES = [
     "Linear",
@@ -60,7 +61,9 @@ class MemoryManager:
         return self.module
 
     @classmethod
-    def attach(cls, module: torch.nn.Module, device: torch.device):
+    def attach(
+        cls, module: torch.nn.Module, device: torch.device, offload_percent: float = 1.0
+    ):
         if hasattr(module, "_memory_manager"):
             # already attached
             return
@@ -71,17 +74,44 @@ class MemoryManager:
         module._mm_to = module.to
         module.to = module._memory_manager.memory_managed_to
 
+        modules_processed = []
         # attach to all modules
         for name, sub_module in module.named_modules():
             for child_name, child_module in sub_module.named_modules():
-                if child_module.__class__.__name__ in LINEAR_MODULES:
-                    # linear
-                    LinearLayerMemoryManager.attach(
-                        child_module, module._memory_manager
-                    )
-                elif child_module.__class__.__name__ in CONV_MODULES:
-                    # conv
-                    ConvLayerMemoryManager.attach(child_module, module._memory_manager)
+                if (
+                    child_module.__class__.__name__ in LINEAR_MODULES
+                    and child_module not in modules_processed
+                ):
+                    skip = False
+                    if offload_percent < 1.0:
+                        # randomly skip some modules
+                        if random.random() > offload_percent:
+                            skip = True
+                    if skip:
+                        module._memory_manager.unmanaged_modules.append(child_module)
+                    else:
+                        # linear
+                        LinearLayerMemoryManager.attach(
+                            child_module, module._memory_manager
+                        )
+                    modules_processed.append(child_module)
+                elif (
+                    child_module.__class__.__name__ in CONV_MODULES
+                    and child_module not in modules_processed
+                ):
+                    skip = False
+                    if offload_percent < 1.0:
+                        # randomly skip some modules
+                        if random.random() > offload_percent:
+                            skip = True
+                    if skip:
+                        module._memory_manager.unmanaged_modules.append(child_module)
+                    else:
+                        # conv
+                        ConvLayerMemoryManager.attach(
+                            child_module, module._memory_manager
+                        )
+                    modules_processed.append(child_module)
                 elif child_module.__class__.__name__ in UNMANAGED_MODULES or any(
                     inc in child_module.__class__.__name__
                     for inc in UNMANAGED_MODULES_INCLUDES
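
Worth noting: the skip logic is stochastic, not a deterministic slice of the network. Each eligible linear or conv layer is offloaded independently with probability offload_percent, so the offloaded fraction only matches the setting in expectation, and the exact layer set varies between runs. A self-contained sketch of the same sampling rule:

    import random

    def is_managed(offload_percent: float) -> bool:
        # Mirrors the branch above: skip (leave the layer unmanaged, i.e. on GPU)
        # when the random draw exceeds offload_percent.
        if offload_percent < 1.0 and random.random() > offload_percent:
            return False
        return True

    random.seed(0)
    trials = 10_000
    managed = sum(is_managed(0.3) for _ in range(trials))
    print(managed / trials)  # ~0.3: about 30% of layers end up memory-managed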

View File

@@ -11,7 +11,7 @@ import {
 import { defaultDatasetConfig } from './jobConfig';
 import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
 import { objectCopy } from '@/utils/basic';
-import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
+import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput, SliderInput } from '@/components/formInputs';
 import Card from '@/components/Card';
 import { X } from 'lucide-react';
 import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal';
@@ -214,17 +214,47 @@ export default function SimpleJob({
               />
             </FormGroup>
           )}
-          {modelArch?.additionalSections?.includes('model.auto_memory') && (
-            <Checkbox
-              label={
-                <>
-                  Auto Memory <IoFlaskSharp className="inline text-yellow-500" name="Experimental" />{' '}
-                </>
-              }
-              checked={jobConfig.config.process[0].model.auto_memory || false}
-              onChange={value => setJobConfig(value, 'config.process[0].model.auto_memory')}
-              docKey="model.auto_memory"
-            />
+          {modelArch?.additionalSections?.includes('model.layer_offloading') && (
+            <>
+              <Checkbox
+                label={
+                  <>
+                    Layer Offloading <IoFlaskSharp className="inline text-yellow-500" name="Experimental" />{' '}
+                  </>
+                }
+                checked={jobConfig.config.process[0].model.layer_offloading || false}
+                onChange={value => setJobConfig(value, 'config.process[0].model.layer_offloading')}
+                docKey="model.layer_offloading"
+              />
+              {jobConfig.config.process[0].model.layer_offloading && (
+                <div className="pt-2">
+                  <SliderInput
+                    label="Transformer Offload %"
+                    value={Math.round(
+                      (jobConfig.config.process[0].model.layer_offloading_transformer_percent ?? 1) * 100,
+                    )}
+                    onChange={value =>
+                      setJobConfig(value * 0.01, 'config.process[0].model.layer_offloading_transformer_percent')
+                    }
+                    min={0}
+                    max={100}
+                    step={1}
+                  />
+                  <SliderInput
+                    label="Text Encoder Offload %"
+                    value={Math.round(
+                      (jobConfig.config.process[0].model.layer_offloading_text_encoder_percent ?? 1) * 100,
+                    )}
+                    onChange={value =>
+                      setJobConfig(value * 0.01, 'config.process[0].model.layer_offloading_text_encoder_percent')
+                    }
+                    min={0}
+                    max={100}
+                    step={1}
+                  />
+                </div>
+              )}
+            </>
           )}
         </Card>
         {disableSections.includes('model.quantize') ? null : (
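
The sliders work in whole percents while the config stores a 0.0-1.0 fraction, so the UI multiplies by 0.01 on write and by 100 on read, defaulting a missing value to 1 (100%). The same round trip expressed in Python, purely for illustration:

    def to_slider(fraction):
        # Config fraction -> slider value; an unset config reads as 100%.
        return round((1.0 if fraction is None else fraction) * 100)

    def to_fraction(slider_value):
        # Slider value -> config fraction, matching the value * 0.01 above.
        return slider_value * 0.01

    print(to_slider(None))   # 100
    print(to_fraction(50))   # 0.5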

View File

@@ -25,7 +25,7 @@ export const defaultSliderConfig: SliderConfig = {
   positive_prompt: 'person who is happy',
   negative_prompt: 'person who is sad',
   target_class: 'person',
-  anchor_class: "",
+  anchor_class: '',
 };
 
 export const defaultJobConfig: JobConfig = {
@@ -181,5 +181,11 @@ export const migrateJobConfig = (jobConfig: JobConfig): JobConfig => {
   if (jobConfig?.config?.process && jobConfig.config.process[0]?.type === 'ui_trainer') {
     jobConfig.config.process[0].type = 'diffusion_trainer';
   }
+  if ('auto_memory' in jobConfig.config.process[0].model) {
+    jobConfig.config.process[0].model.layer_offloading = (jobConfig.config.process[0].model.auto_memory ||
+      false) as boolean;
+    delete jobConfig.config.process[0].model.auto_memory;
+  }
   return jobConfig;
 };
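
The migration step renames the legacy flag on saved jobs so they load cleanly under the new schema. A Python paraphrase of the TypeScript above, using a dict stand-in for the model config:

    def migrate_model(model: dict) -> dict:
        # Copy the deprecated auto_memory flag to layer_offloading, then drop it.
        if "auto_memory" in model:
            model["layer_offloading"] = bool(model.pop("auto_memory"))
        return model

    print(migrate_model({"auto_memory": True}))  # {'layer_offloading': True}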

View File

@@ -20,7 +20,7 @@ type AdditionalSections
   | 'sample.multi_ctrl_imgs'
   | 'datasets.num_frames'
   | 'model.multistage'
-  | 'model.auto_memory'
+  | 'model.layer_offloading'
   | 'model.low_vram';
 
 type ModelGroup = 'image' | 'instruction' | 'video';
@@ -313,7 +313,7 @@ export const modelArchs: ModelArch[] = [
       'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
     },
     disableSections: ['network.conv'],
-    additionalSections: ['model.low_vram', 'model.auto_memory'],
+    additionalSections: ['model.low_vram', 'model.layer_offloading'],
     accuracyRecoveryAdapters: {
       '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors',
     },
@@ -334,7 +334,7 @@ export const modelArchs: ModelArch[] = [
       'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
     },
     disableSections: ['network.conv'],
-    additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram', 'model.auto_memory'],
+    additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram', 'model.layer_offloading'],
     accuracyRecoveryAdapters: {
       '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors',
     },
@@ -360,7 +360,7 @@ export const modelArchs: ModelArch[] = [
       'datasets.multi_control_paths',
       'sample.multi_ctrl_imgs',
       'model.low_vram',
-      'model.auto_memory',
+      'model.layer_offloading',
     ],
     accuracyRecoveryAdapters: {
       '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors',

View File

@@ -21,17 +21,21 @@ export const handleModelArchChange = (
     setJobConfig(false, 'config.process[0].model.low_vram');
   }
 
-  // handle auto memory setting
-  if (!newArch?.additionalSections?.includes('model.auto_memory')) {
-    if ('auto_memory' in jobConfig.config.process[0].model) {
+  // handle layer offloading setting
+  if (!newArch?.additionalSections?.includes('model.layer_offloading')) {
+    if ('layer_offloading' in jobConfig.config.process[0].model) {
       const newModel = objectCopy(jobConfig.config.process[0].model);
-      delete newModel.auto_memory;
+      delete newModel.layer_offloading;
+      delete newModel.layer_offloading_text_encoder_percent;
+      delete newModel.layer_offloading_transformer_percent;
       setJobConfig(newModel, 'config.process[0].model');
     }
   } else {
     // set to false if not set
-    if (!('auto_memory' in jobConfig.config.process[0].model)) {
-      setJobConfig(false, 'config.process[0].model.auto_memory');
+    if (!('layer_offloading' in jobConfig.config.process[0].model)) {
+      setJobConfig(false, 'config.process[0].model.layer_offloading');
+      setJobConfig(1.0, 'config.process[0].model.layer_offloading_text_encoder_percent');
+      setJobConfig(1.0, 'config.process[0].model.layer_offloading_transformer_percent');
     }
   }
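
So switching to an architecture without offloading support strips all three keys, while switching to one that supports it seeds defaults (off, with both percents at 100%). The invariant as a Python sketch, again with a dict standing in for the model config:

    def on_arch_change(model: dict, supports_offloading: bool) -> dict:
        keys = (
            "layer_offloading",
            "layer_offloading_text_encoder_percent",
            "layer_offloading_transformer_percent",
        )
        if not supports_offloading:
            # Architecture dropped support: remove any stale offloading keys.
            for key in keys:
                model.pop(key, None)
        elif "layer_offloading" not in model:
            # Newly supported and unset: seed the defaults.
            model.update({keys[0]: False, keys[1]: 1.0, keys[2]: 1.0})
        return model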

View File

@@ -296,3 +296,147 @@ export const FormGroup: React.FC<FormGroupProps> = props => {
     </div>
   );
 };
+
+export interface SliderInputProps extends InputProps {
+  value: number;
+  onChange: (value: number) => void;
+  min: number;
+  max: number;
+  step?: number;
+  disabled?: boolean;
+  showValue?: boolean;
+}
+
+export const SliderInput: React.FC<SliderInputProps> = props => {
+  const { label, value, onChange, min, max, step = 1, disabled, className, docKey = null, showValue = true } = props;
+  let { doc } = props;
+  if (!doc && docKey) {
+    doc = getDoc(docKey);
+  }
+
+  const trackRef = React.useRef<HTMLDivElement | null>(null);
+  const [dragging, setDragging] = React.useState(false);
+
+  const clamp = (v: number) => (v < min ? min : v > max ? max : v);
+
+  const snapToStep = (v: number) => {
+    if (!Number.isFinite(v)) return min;
+    const steps = Math.round((v - min) / step);
+    const snapped = min + steps * step;
+    return clamp(Number(snapped.toFixed(6)));
+  };
+
+  const percent = React.useMemo(() => {
+    if (max === min) return 0;
+    const p = ((value - min) / (max - min)) * 100;
+    return p < 0 ? 0 : p > 100 ? 100 : p;
+  }, [value, min, max]);
+
+  const calcFromClientX = React.useCallback(
+    (clientX: number) => {
+      const el = trackRef.current;
+      if (!el || !Number.isFinite(clientX)) return;
+      const rect = el.getBoundingClientRect();
+      const width = rect.right - rect.left;
+      if (!(width > 0)) return;
+      // Clamp ratio to [0, 1] so it can never flip ends.
+      const ratioRaw = (clientX - rect.left) / width;
+      const ratio = ratioRaw <= 0 ? 0 : ratioRaw >= 1 ? 1 : ratioRaw;
+      const raw = min + ratio * (max - min);
+      onChange(snapToStep(raw));
+    },
+    [min, max, step, onChange],
+  );
+
+  // Mouse/touch pointer drag
+  const onPointerDown = (e: React.PointerEvent) => {
+    if (disabled) return;
+    e.preventDefault();
+    // Capture the pointer so moves outside the element are still tracked correctly
+    try {
+      (e.currentTarget as HTMLElement).setPointerCapture?.(e.pointerId);
+    } catch {}
+    setDragging(true);
+    calcFromClientX(e.clientX);
+
+    const handleMove = (ev: PointerEvent) => {
+      ev.preventDefault();
+      calcFromClientX(ev.clientX);
+    };
+    const handleUp = (ev: PointerEvent) => {
+      setDragging(false);
+      // release capture if we got it
+      try {
+        (e.currentTarget as HTMLElement).releasePointerCapture?.(e.pointerId);
+      } catch {}
+      window.removeEventListener('pointermove', handleMove);
+      window.removeEventListener('pointerup', handleUp);
+    };
+    window.addEventListener('pointermove', handleMove);
+    window.addEventListener('pointerup', handleUp);
+  };
+
+  return (
+    <div className={classNames(className, disabled ? 'opacity-30 cursor-not-allowed' : '')}>
+      {label && (
+        <label className={labelClasses}>
+          {label}{' '}
+          {doc && (
+            <div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
+              <CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
+            </div>
+          )}
+        </label>
+      )}
+      <div className="flex items-center gap-3">
+        <div className="flex-1">
+          <div
+            ref={trackRef}
+            onPointerDown={onPointerDown}
+            className={classNames(
+              'relative w-full h-6 select-none outline-none',
+              disabled ? 'pointer-events-none' : 'cursor-pointer',
+            )}
+          >
+            {/* Thicker track */}
+            <div className="pointer-events-none absolute left-0 right-0 top-1/2 -translate-y-1/2 h-3 rounded-sm bg-gray-800 border border-gray-700" />
+            {/* Fill */}
+            <div
+              className="pointer-events-none absolute left-0 top-1/2 -translate-y-1/2 h-3 rounded-sm bg-blue-600"
+              style={{ width: `${percent}%` }}
+            />
+            {/* Thumb */}
+            <div
+              onPointerDown={onPointerDown}
+              className={classNames(
+                'absolute top-1/2 -translate-y-1/2 -ml-2',
+                'h-4 w-4 rounded-full bg-white shadow border border-gray-300 cursor-pointer',
+                'after:content-[""] after:absolute after:inset-[-6px] after:rounded-full after:bg-transparent', // expands hit area
+                dragging ? 'ring-2 ring-blue-600' : '',
+              )}
+              style={{ left: `calc(${percent}% )` }}
+            />
+          </div>
+          <div className="flex justify-between text-xs text-gray-500 mt-0.5 select-none">
+            <span>{min}</span>
+            <span>{max}</span>
+          </div>
+        </div>
+        {showValue && (
+          <div className="min-w-[3.5rem] text-right text-sm px-3 py-1 bg-gray-800 border border-gray-700 rounded-sm">
+            {Number.isFinite(value) ? value : ''}
+          </div>
+        )}
+      </div>
+    </div>
+  );
+};
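
The slider converts a pointer position to a value by clamping the ratio along the track to [0, 1], scaling into [min, max], and snapping to the nearest step. The snapping math in isolation, transliterated from snapToStep above into Python:

    def snap_to_step(v: float, lo: float, hi: float, step: float = 1.0) -> float:
        # Round to the nearest whole number of steps from the minimum,
        # then clamp back into [lo, hi]; toFixed(6) becomes round(_, 6).
        steps = round((v - lo) / step)
        snapped = round(lo + steps * step, 6)
        return min(hi, max(lo, snapped))

    print(snap_to_step(42.7, 0, 100))      # 43.0
    print(snap_to_step(-5, 0, 100))        # 0 (clamped to the minimum)
    print(snap_to_step(0.37, 0, 1, 0.05))  # 0.35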

View File

@@ -185,10 +185,10 @@ const docs: { [key: string]: ConfigDoc } = {
       </>
     ),
   },
-  'model.auto_memory': {
+  'model.layer_offloading': {
     title: (
       <>
-        Auto Memory{' '}
+        Layer Offloading{' '}
         <span className="text-yellow-500">
           ( <IoFlaskSharp className="inline text-yellow-500" name="Experimental" /> Experimental)
         </span>
@@ -204,10 +204,14 @@ const docs: { [key: string]: ConfigDoc } = {
         one update to the next. It will also only work with certain models.
         <br />
         <br />
-        Auto Memory uses the CPU RAM instead of the GPU ram to hold most of the model weights. This allows training a
+        Layer Offloading uses the CPU RAM instead of the GPU ram to hold most of the model weights. This allows training a
         much larger model on a smaller GPU, assuming you have enough CPU RAM. This is slower than training on pure GPU
         RAM, but CPU RAM is cheaper and upgradeable. You will still need GPU RAM to hold the optimizer states and LoRA weights,
         so a larger card is usually still needed.
+        <br />
+        <br />
+        You can also select the percentage of the layers to offload. It is generally best to offload as few as possible (close to 0%)
+        for best performance, but you can offload more if you need the memory.
       </>
     ),
   },
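
To make the doc's guidance concrete, some rough weight-budget arithmetic (the model size and dtype here are assumptions for illustration, not a benchmark of any particular model):

    params = 20e9        # suppose a ~20B-parameter transformer
    bytes_per_param = 1  # float8 weights
    offload = 0.75       # layer_offloading_transformer_percent
    total_gb = params * bytes_per_param / 1e9
    print(f"~{total_gb * offload:.0f} GB in CPU RAM, ~{total_gb * (1 - offload):.0f} GB on GPU")
    # ~15 GB in CPU RAM, ~5 GB on GPU; optimizer state and LoRA weights still need GPU RAM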

View File

@@ -153,7 +153,9 @@ export interface ModelConfig {
   arch: string;
   low_vram: boolean;
   model_kwargs: { [key: string]: any };
-  auto_memory?: boolean;
+  layer_offloading?: boolean;
+  layer_offloading_transformer_percent?: number;
+  layer_offloading_text_encoder_percent?: number;
 }
 
 export interface SampleItem {