From 1bc6dee127b1c7abdd8b7f9a11bb89f730dfc834 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Fri, 10 Oct 2025 13:12:32 -0600
Subject: [PATCH] Change auto_memory to layer_offloading and allow you to set
 the amount to offload

---
 .../diffusion_models/qwen_image/qwen_image.py |  16 +-
 jobs/process/BaseSDTrainProcess.py            |   2 +-
 toolkit/config_modules.py                     |  14 +-
 toolkit/memory_management/manager.py          |  48 ++++--
 ui/src/app/jobs/new/SimpleJob.tsx             |  54 +++++--
 ui/src/app/jobs/new/jobConfig.ts              |   8 +-
 ui/src/app/jobs/new/options.ts                |   8 +-
 ui/src/app/jobs/new/utils.ts                  |  16 +-
 ui/src/components/formInputs.tsx              | 144 ++++++++++++++++++
 ui/src/docs.tsx                               |  10 +-
 ui/src/types.ts                               |   4 +-
 11 files changed, 279 insertions(+), 45 deletions(-)

diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py
index 1828f835..b9c78374 100644
--- a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py
+++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py
@@ -125,8 +125,12 @@ class QwenImageModel(BaseModel):
             quantize_model(self, transformer)
             flush()
 
-        if self.model_config.auto_memory:
-            MemoryManager.attach(transformer, self.device_torch)
+        if self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0:
+            MemoryManager.attach(
+                transformer,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_transformer_percent
+            )
 
         if self.model_config.low_vram:
             self.print_and_status_update("Moving transformer to CPU")
@@ -147,8 +151,12 @@ class QwenImageModel(BaseModel):
         if not self._qwen_image_keep_visual:
             text_encoder.model.visual = None
 
-        if self.model_config.auto_memory:
-            MemoryManager.attach(text_encoder, self.device_torch)
+        if self.model_config.layer_offloading and self.model_config.layer_offloading_text_encoder_percent > 0:
+            MemoryManager.attach(
+                text_encoder,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_text_encoder_percent
+            )
 
         text_encoder.to(self.device_torch, dtype=dtype)
         flush()
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 6a696d6c..d376b299 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1759,7 +1759,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             )
 
             # we cannot merge in if quantized
-            if self.model_config.quantize or self.model_config.auto_memory:
+            if self.model_config.quantize or self.model_config.layer_offloading:
                 # todo find a way around this
                 self.network.can_merge_in = False
 
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index cb3c6c80..1fa7c688 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -626,12 +626,18 @@ class ModelConfig:
 
         # auto memory management, only for some models
         self.auto_memory = kwargs.get("auto_memory", False)
-        if self.auto_memory and self.qtype == "qfloat8":
-            print(f"Auto memory is not compatible with qfloat8, switching to float8 for model")
+        # auto memory is deprecated, use layer offloading instead
+        if self.auto_memory:
+            print("auto_memory is deprecated, use layer_offloading instead")
+        self.layer_offloading = kwargs.get("layer_offloading", self.auto_memory)
+        if self.layer_offloading and self.qtype == "qfloat8":
             self.qtype = "float8"
-        if self.auto_memory and not self.qtype_te == "qfloat8":
-            print(f"Auto memory is not compatible with qfloat8, switching to float8 for te")
+        if self.layer_offloading and not self.qtype_te == "qfloat8":
             self.qtype_te = "float8"
+
+        # 0 is off and 1.0 is 100% of the layers
+        self.layer_offloading_transformer_percent = kwargs.get("layer_offloading_transformer_percent", 1.0)
+        self.layer_offloading_text_encoder_percent = kwargs.get("layer_offloading_text_encoder_percent", 1.0)
 
         # can be used to load the extras like text encoder or vae from here
         # only setup for some models but will prevent having to download the te for
diff --git a/toolkit/memory_management/manager.py b/toolkit/memory_management/manager.py
index fa7a2d07..dbb0ab8d 100644
--- a/toolkit/memory_management/manager.py
+++ b/toolkit/memory_management/manager.py
@@ -1,5 +1,6 @@
 import torch
 from .manager_modules import LinearLayerMemoryManager, ConvLayerMemoryManager
+import random
 
 LINEAR_MODULES = [
     "Linear",
@@ -60,7 +61,9 @@ class MemoryManager:
         return self.module
 
     @classmethod
-    def attach(cls, module: torch.nn.Module, device: torch.device):
+    def attach(
+        cls, module: torch.nn.Module, device: torch.device, offload_percent: float = 1.0
+    ):
         if hasattr(module, "_memory_manager"):
             # already attached
             return
@@ -71,17 +74,44 @@ class MemoryManager:
         module._mm_to = module.to
         module.to = module._memory_manager.memory_managed_to
 
+        modules_processed = []
         # attach to all modules
         for name, sub_module in module.named_modules():
             for child_name, child_module in sub_module.named_modules():
-                if child_module.__class__.__name__ in LINEAR_MODULES:
-                    # linear
-                    LinearLayerMemoryManager.attach(
-                        child_module, module._memory_manager
-                    )
-                elif child_module.__class__.__name__ in CONV_MODULES:
-                    # conv
-                    ConvLayerMemoryManager.attach(child_module, module._memory_manager)
+                if (
+                    child_module.__class__.__name__ in LINEAR_MODULES
+                    and child_module not in modules_processed
+                ):
+                    skip = False
+                    if offload_percent < 1.0:
+                        # randomly skip some modules
+                        if random.random() > offload_percent:
+                            skip = True
+                    if skip:
+                        module._memory_manager.unmanaged_modules.append(child_module)
+                    else:
+                        # linear
+                        LinearLayerMemoryManager.attach(
+                            child_module, module._memory_manager
+                        )
+                    modules_processed.append(child_module)
+                elif (
+                    child_module.__class__.__name__ in CONV_MODULES
+                    and child_module not in modules_processed
+                ):
+                    skip = False
+                    if offload_percent < 1.0:
+                        # randomly skip some modules
+                        if random.random() > offload_percent:
+                            skip = True
+                    if skip:
+                        module._memory_manager.unmanaged_modules.append(child_module)
+                    else:
+                        # conv
+                        ConvLayerMemoryManager.attach(
+                            child_module, module._memory_manager
+                        )
+                    modules_processed.append(child_module)
                 elif child_module.__class__.__name__ in UNMANAGED_MODULES or any(
                     inc in child_module.__class__.__name__
                     for inc in UNMANAGED_MODULES_INCLUDES
diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx
index 6e0367a7..964a7309 100644
--- a/ui/src/app/jobs/new/SimpleJob.tsx
+++ b/ui/src/app/jobs/new/SimpleJob.tsx
@@ -11,7 +11,7 @@ import {
 import { defaultDatasetConfig } from './jobConfig';
 import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
 import { objectCopy } from '@/utils/basic';
-import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
+import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput, SliderInput } from '@/components/formInputs';
 import Card from '@/components/Card';
 import { X } from 'lucide-react';
 import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal';
@@ -214,17 +214,47 @@ export default function SimpleJob({
             />
           )}
 
-          {modelArch?.additionalSections?.includes('model.auto_memory') && (
-            <Checkbox
-              label={
-                <>
-                  Auto Memory{' '}
-                </>
-              }
-              checked={jobConfig.config.process[0].model.auto_memory || false}
-              onChange={value => setJobConfig(value, 'config.process[0].model.auto_memory')}
-              docKey="model.auto_memory"
-            />
-          )}
+          {modelArch?.additionalSections?.includes('model.layer_offloading') && (
+            <>
+              <Checkbox
+                label={
+                  <>
+                    Layer Offloading{' '}
+                  </>
+                }
+                checked={jobConfig.config.process[0].model.layer_offloading || false}
+                onChange={value => setJobConfig(value, 'config.process[0].model.layer_offloading')}
+                docKey="model.layer_offloading"
+              />
+              {jobConfig.config.process[0].model.layer_offloading && (
+                <div>
+                  <SliderInput
+                    onChange={value =>
+                      setJobConfig(value * 0.01, 'config.process[0].model.layer_offloading_transformer_percent')
+                    }
+                    min={0}
+                    max={100}
+                    step={1}
+                  />
+                  <SliderInput
+                    onChange={value =>
+                      setJobConfig(value * 0.01, 'config.process[0].model.layer_offloading_text_encoder_percent')
+                    }
+                    min={0}
+                    max={100}
+                    step={1}
+                  />
+                </div>
+              )}
+            </>
+          )}
 
         {disableSections.includes('model.quantize') ? null : (
diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts
index 54f34ca9..6f3a891a 100644
--- a/ui/src/app/jobs/new/jobConfig.ts
+++ b/ui/src/app/jobs/new/jobConfig.ts
@@ -25,7 +25,7 @@ export const defaultSliderConfig: SliderConfig = {
   positive_prompt: 'person who is happy',
   negative_prompt: 'person who is sad',
   target_class: 'person',
-  anchor_class: "",
+  anchor_class: '',
 };
 
 export const defaultJobConfig: JobConfig = {
@@ -181,5 +181,11 @@ export const migrateJobConfig = (jobConfig: JobConfig): JobConfig => {
   if (jobConfig?.config?.process && jobConfig.config.process[0]?.type === 'ui_trainer') {
     jobConfig.config.process[0].type = 'diffusion_trainer';
   }
+
+  if ('auto_memory' in jobConfig.config.process[0].model) {
+    jobConfig.config.process[0].model.layer_offloading = (jobConfig.config.process[0].model.auto_memory ||
+      false) as boolean;
+    delete jobConfig.config.process[0].model.auto_memory;
+  }
   return jobConfig;
 };
diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts
index f9d04786..37585a24 100644
--- a/ui/src/app/jobs/new/options.ts
+++ b/ui/src/app/jobs/new/options.ts
@@ -20,7 +20,7 @@ type AdditionalSections =
   | 'sample.multi_ctrl_imgs'
   | 'datasets.num_frames'
   | 'model.multistage'
-  | 'model.auto_memory'
+  | 'model.layer_offloading'
   | 'model.low_vram';
 
 type ModelGroup = 'image' | 'instruction' | 'video';
@@ -313,7 +313,7 @@ export const modelArchs: ModelArch[] = [
       'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
     },
     disableSections: ['network.conv'],
-    additionalSections: ['model.low_vram', 'model.auto_memory'],
+    additionalSections: ['model.low_vram', 'model.layer_offloading'],
     accuracyRecoveryAdapters: {
       '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors',
     },
@@ -334,7 +334,7 @@ export const modelArchs: ModelArch[] = [
       'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
     },
     disableSections: ['network.conv'],
-    additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram', 'model.auto_memory'],
+    additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram', 'model.layer_offloading'],
     accuracyRecoveryAdapters: {
       '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors',
     },
@@ -360,7 +360,7 @@ export const modelArchs: ModelArch[] = [
       'datasets.multi_control_paths',
       'sample.multi_ctrl_imgs',
       'model.low_vram',
-      'model.auto_memory',
+      'model.layer_offloading',
     ],
     accuracyRecoveryAdapters: {
       '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors',
diff --git a/ui/src/app/jobs/new/utils.ts b/ui/src/app/jobs/new/utils.ts
index 912ea9f2..e216333a 100644
--- a/ui/src/app/jobs/new/utils.ts
+++ b/ui/src/app/jobs/new/utils.ts
@@ -21,17 +21,21 @@ export const handleModelArchChange = (
     setJobConfig(false, 'config.process[0].model.low_vram');
   }
 
-  // handle auto memory setting
-  if (!newArch?.additionalSections?.includes('model.auto_memory')) {
-    if ('auto_memory' in jobConfig.config.process[0].model) {
+  // handle layer offloading setting
+  if (!newArch?.additionalSections?.includes('model.layer_offloading')) {
+    if ('layer_offloading' in jobConfig.config.process[0].model) {
       const newModel = objectCopy(jobConfig.config.process[0].model);
-      delete newModel.auto_memory;
+      delete newModel.layer_offloading;
+      delete newModel.layer_offloading_text_encoder_percent;
+      delete newModel.layer_offloading_transformer_percent;
       setJobConfig(newModel, 'config.process[0].model');
     }
   } else {
     // set to false if not set
-    if (!('auto_memory' in jobConfig.config.process[0].model)) {
-      setJobConfig(false, 'config.process[0].model.auto_memory');
+    if (!('layer_offloading' in jobConfig.config.process[0].model)) {
+      setJobConfig(false, 'config.process[0].model.layer_offloading');
+      setJobConfig(1.0, 'config.process[0].model.layer_offloading_text_encoder_percent');
+      setJobConfig(1.0, 'config.process[0].model.layer_offloading_transformer_percent');
     }
   }
 
diff --git a/ui/src/components/formInputs.tsx b/ui/src/components/formInputs.tsx
index 90f543db..46ec7918 100644
--- a/ui/src/components/formInputs.tsx
+++ b/ui/src/components/formInputs.tsx
@@ -296,3 +296,147 @@ export const FormGroup: React.FC = props => {
     </div>
   );
 };
+
+export interface SliderInputProps extends InputProps {
+  value: number;
+  onChange: (value: number) => void;
+  min: number;
+  max: number;
+  step?: number;
+  disabled?: boolean;
+  showValue?: boolean;
+}
+
+export const SliderInput: React.FC<SliderInputProps> = props => {
+  const { label, value, onChange, min, max, step = 1, disabled, className, docKey = null, showValue = true } = props;
+  let { doc } = props;
+  if (!doc && docKey) {
+    doc = getDoc(docKey);
+  }
+
+  const trackRef = React.useRef<HTMLDivElement>(null);
+  const [dragging, setDragging] = React.useState(false);
+
+  const clamp = (v: number) => (v < min ? min : v > max ? max : v);
+  const snapToStep = (v: number) => {
+    if (!Number.isFinite(v)) return min;
+    const steps = Math.round((v - min) / step);
+    const snapped = min + steps * step;
+    return clamp(Number(snapped.toFixed(6)));
+  };
+
+  const percent = React.useMemo(() => {
+    if (max === min) return 0;
+    const p = ((value - min) / (max - min)) * 100;
+    return p < 0 ? 0 : p > 100 ? 100 : p;
+  }, [value, min, max]);
+
+  const calcFromClientX = React.useCallback(
+    (clientX: number) => {
+      const el = trackRef.current;
+      if (!el || !Number.isFinite(clientX)) return;
+      const rect = el.getBoundingClientRect();
+      const width = rect.right - rect.left;
+      if (!(width > 0)) return;
+
+      // Clamp ratio to [0, 1] so it can never flip ends.
+      const ratioRaw = (clientX - rect.left) / width;
+      const ratio = ratioRaw <= 0 ? 0 : ratioRaw >= 1 ? 1 : ratioRaw;
+
+      const raw = min + ratio * (max - min);
+      onChange(snapToStep(raw));
+    },
+    [min, max, step, onChange],
+  );
+
+  // Mouse/touch pointer drag
+  const onPointerDown = (e: React.PointerEvent) => {
+    if (disabled) return;
+    e.preventDefault();
+
+    // Capture the pointer so moves outside the element are still tracked correctly
+    try {
+      (e.currentTarget as HTMLElement).setPointerCapture?.(e.pointerId);
+    } catch {}
+
+    setDragging(true);
+    calcFromClientX(e.clientX);
+
+    const handleMove = (ev: PointerEvent) => {
+      ev.preventDefault();
+      calcFromClientX(ev.clientX);
+    };
+    const handleUp = (ev: PointerEvent) => {
+      setDragging(false);
+      // release capture if we got it
+      try {
+        (e.currentTarget as HTMLElement).releasePointerCapture?.(e.pointerId);
+      } catch {}
+      window.removeEventListener('pointermove', handleMove);
+      window.removeEventListener('pointerup', handleUp);
+    };
+
+    window.addEventListener('pointermove', handleMove);
+    window.addEventListener('pointerup', handleUp);
+  };
+
+  return (
+    <div className={className}>
+      {label && (
+        <label>{label}</label>
+      )}
+
+      <div>
+        <div>
+          <div ref={trackRef} onPointerDown={onPointerDown}>
+            {/* Thicker track */}
+            <div />
+
+            {/* Fill */}
+            <div style={{ width: `${percent}%` }} />
+
+            {/* Thumb */}
+            <div style={{ left: `${percent}%` }} />
+          </div>
+
+          <div>
+            {min}
+            {max}
+          </div>
+        </div>
+
+        {showValue && (
+          <div>
+            {Number.isFinite(value) ? value : ''}
+          </div>
+        )}
+      </div>
+    </div>
+  );
+};
diff --git a/ui/src/docs.tsx b/ui/src/docs.tsx
index 1c4bfc97..f929b2f3 100644
--- a/ui/src/docs.tsx
+++ b/ui/src/docs.tsx
@@ -185,10 +185,10 @@ const docs: { [key: string]: ConfigDoc } = {
       </>
     ),
   },
-  'model.auto_memory': {
+  'model.layer_offloading': {
     title: (
       <>
-        Auto Memory{' '}
+        Layer Offloading{' '}
         (Experimental)
       </>
     ),
@@ -204,10 +204,14 @@ const docs: { [key: string]: ConfigDoc } = {
         one update to the next. It will also only work with certain models.
         <br />
         <br />
-        Auto Memory uses the CPU RAM instead of the GPU ram to hold most of the model weights. This allows training a
+        Layer Offloading uses CPU RAM instead of GPU RAM to hold most of the model weights. This allows training a
         much larger model on a smaller GPU, assuming you have enough CPU RAM. This is slower than training on pure GPU
         RAM, but CPU RAM is cheaper and upgradeable. You will still need GPU RAM to hold the optimizer states and LoRA
         weights, so a larger card is usually still needed.
+        <br />
+        <br />
+        You can also select the percentage of layers to offload. It is generally best to offload as few layers as
+        possible (close to 0%) for best performance, but you can offload more if you need the memory.
       </>
     ),
   },
diff --git a/ui/src/types.ts b/ui/src/types.ts
index 3fa58153..d86220a6 100644
--- a/ui/src/types.ts
+++ b/ui/src/types.ts
@@ -153,7 +153,9 @@ export interface ModelConfig {
   arch: string;
   low_vram: boolean;
   model_kwargs: { [key: string]: any };
-  auto_memory?: boolean;
+  layer_offloading?: boolean;
+  layer_offloading_transformer_percent?: number;
+  layer_offloading_text_encoder_percent?: number;
 }
 
 export interface SampleItem {
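
Usage notes (sketches, not part of the patch above):

The new knobs arrive as plain ModelConfig kwargs. A minimal sketch of enabling them from Python; the model path here is illustrative, and note that the patch coerces a "qfloat8" qtype to "float8" whenever layer_offloading is on:

    # Sketch only: kwarg names come from this patch. Both percents default to
    # 1.0 (offload 100% of eligible layers); 0 disables offloading for that part.
    from toolkit.config_modules import ModelConfig

    model_config = ModelConfig(
        name_or_path="Qwen/Qwen-Image",             # illustrative path
        layer_offloading=True,                      # replaces the deprecated auto_memory flag
        layer_offloading_transformer_percent=0.5,   # offload ~50% of transformer layers
        layer_offloading_text_encoder_percent=1.0,  # offload all text encoder layers
    )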
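
The manager call itself takes the percent directly. A runnable sketch of the new MemoryManager.attach signature, using a toy module in place of the real transformer (in the trainer the module is the loaded transformer or text encoder, as in qwen_image.py above):

    import torch
    import torch.nn as nn

    from toolkit.memory_management.manager import MemoryManager

    # Toy stand-in: any module whose children include Linear/Conv layers is eligible.
    model = nn.Sequential(*[nn.Linear(64, 64) for _ in range(8)])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Each eligible layer is offloaded with probability 0.25 and recorded as
    # unmanaged otherwise; .to() is then routed through memory_managed_to.
    MemoryManager.attach(model, device, offload_percent=0.25)
    model.to(device)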
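
Because layers are skipped by a per-layer random draw rather than a fixed slice, the realized split is approximate and varies run to run. A back-of-envelope way to reason about the memory split (illustrative numbers, not measurements):

    # With offload_percent = p, each eligible layer lands in CPU RAM with
    # probability p, so the expected split of eligible weights is p : (1 - p).
    eligible_weight_gb = 20.0  # hypothetical total for managed Linear/Conv weights
    p = 0.5                    # e.g. layer_offloading_transformer_percent
    print(f"~{eligible_weight_gb * p:.0f} GB expected in CPU RAM")
    print(f"~{eligible_weight_gb * (1 - p):.0f} GB expected to stay on GPU")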