mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Change auto_memory to be layer_offloading and allow you to set the amount to unload
This commit is contained in:
@@ -125,8 +125,12 @@ class QwenImageModel(BaseModel):
|
|||||||
quantize_model(self, transformer)
|
quantize_model(self, transformer)
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
if self.model_config.auto_memory:
|
if self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0:
|
||||||
MemoryManager.attach(transformer, self.device_torch)
|
MemoryManager.attach(
|
||||||
|
transformer,
|
||||||
|
self.device_torch,
|
||||||
|
offload_percent=self.model_config.layer_offloading_transformer_percent
|
||||||
|
)
|
||||||
|
|
||||||
if self.model_config.low_vram:
|
if self.model_config.low_vram:
|
||||||
self.print_and_status_update("Moving transformer to CPU")
|
self.print_and_status_update("Moving transformer to CPU")
|
||||||
@@ -147,8 +151,12 @@ class QwenImageModel(BaseModel):
|
|||||||
if not self._qwen_image_keep_visual:
|
if not self._qwen_image_keep_visual:
|
||||||
text_encoder.model.visual = None
|
text_encoder.model.visual = None
|
||||||
|
|
||||||
if self.model_config.auto_memory:
|
if self.model_config.layer_offloading and self.model_config.layer_offloading_text_encoder_percent > 0:
|
||||||
MemoryManager.attach(text_encoder, self.device_torch)
|
MemoryManager.attach(
|
||||||
|
text_encoder,
|
||||||
|
self.device_torch,
|
||||||
|
offload_percent=self.model_config.layer_offloading_text_encoder_percent
|
||||||
|
)
|
||||||
|
|
||||||
text_encoder.to(self.device_torch, dtype=dtype)
|
text_encoder.to(self.device_torch, dtype=dtype)
|
||||||
flush()
|
flush()
|
||||||
|
|||||||
@@ -1759,7 +1759,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# we cannot merge in if quantized
|
# we cannot merge in if quantized
|
||||||
if self.model_config.quantize or self.model_config.auto_memory:
|
if self.model_config.quantize or self.model_config.layer_offloading:
|
||||||
# todo find a way around this
|
# todo find a way around this
|
||||||
self.network.can_merge_in = False
|
self.network.can_merge_in = False
|
||||||
|
|
||||||
|
|||||||
@@ -626,12 +626,18 @@ class ModelConfig:
|
|||||||
|
|
||||||
# auto memory management, only for some models
|
# auto memory management, only for some models
|
||||||
self.auto_memory = kwargs.get("auto_memory", False)
|
self.auto_memory = kwargs.get("auto_memory", False)
|
||||||
if self.auto_memory and self.qtype == "qfloat8":
|
# auto memory is deprecated, use layer offloading instead
|
||||||
print(f"Auto memory is not compatible with qfloat8, switching to float8 for model")
|
if self.auto_memory:
|
||||||
|
print("auto_memory is deprecated, use layer_offloading instead")
|
||||||
|
self.layer_offloading = kwargs.get("layer_offloading", self.auto_memory )
|
||||||
|
if self.layer_offloading and self.qtype == "qfloat8":
|
||||||
self.qtype = "float8"
|
self.qtype = "float8"
|
||||||
if self.auto_memory and not self.qtype_te == "qfloat8":
|
if self.layer_offloading and not self.qtype_te == "qfloat8":
|
||||||
print(f"Auto memory is not compatible with qfloat8, switching to float8 for te")
|
|
||||||
self.qtype_te = "float8"
|
self.qtype_te = "float8"
|
||||||
|
|
||||||
|
# 0 is off and 1.0 is 100% of the layers
|
||||||
|
self.layer_offloading_transformer_percent = kwargs.get("layer_offloading_transformer_percent", 1.0)
|
||||||
|
self.layer_offloading_text_encoder_percent = kwargs.get("layer_offloading_text_encoder_percent", 1.0)
|
||||||
|
|
||||||
# can be used to load the extras like text encoder or vae from here
|
# can be used to load the extras like text encoder or vae from here
|
||||||
# only setup for some models but will prevent having to download the te for
|
# only setup for some models but will prevent having to download the te for
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
from .manager_modules import LinearLayerMemoryManager, ConvLayerMemoryManager
|
from .manager_modules import LinearLayerMemoryManager, ConvLayerMemoryManager
|
||||||
|
import random
|
||||||
|
|
||||||
LINEAR_MODULES = [
|
LINEAR_MODULES = [
|
||||||
"Linear",
|
"Linear",
|
||||||
@@ -60,7 +61,9 @@ class MemoryManager:
|
|||||||
return self.module
|
return self.module
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def attach(cls, module: torch.nn.Module, device: torch.device):
|
def attach(
|
||||||
|
cls, module: torch.nn.Module, device: torch.device, offload_percent: float = 1.0
|
||||||
|
):
|
||||||
if hasattr(module, "_memory_manager"):
|
if hasattr(module, "_memory_manager"):
|
||||||
# already attached
|
# already attached
|
||||||
return
|
return
|
||||||
@@ -71,17 +74,44 @@ class MemoryManager:
|
|||||||
module._mm_to = module.to
|
module._mm_to = module.to
|
||||||
module.to = module._memory_manager.memory_managed_to
|
module.to = module._memory_manager.memory_managed_to
|
||||||
|
|
||||||
|
modules_processed = []
|
||||||
# attach to all modules
|
# attach to all modules
|
||||||
for name, sub_module in module.named_modules():
|
for name, sub_module in module.named_modules():
|
||||||
for child_name, child_module in sub_module.named_modules():
|
for child_name, child_module in sub_module.named_modules():
|
||||||
if child_module.__class__.__name__ in LINEAR_MODULES:
|
if (
|
||||||
# linear
|
child_module.__class__.__name__ in LINEAR_MODULES
|
||||||
LinearLayerMemoryManager.attach(
|
and child_module not in modules_processed
|
||||||
child_module, module._memory_manager
|
):
|
||||||
)
|
skip = False
|
||||||
elif child_module.__class__.__name__ in CONV_MODULES:
|
if offload_percent < 1.0:
|
||||||
# conv
|
# randomly skip some modules
|
||||||
ConvLayerMemoryManager.attach(child_module, module._memory_manager)
|
if random.random() > offload_percent:
|
||||||
|
skip = True
|
||||||
|
if skip:
|
||||||
|
module._memory_manager.unmanaged_modules.append(child_module)
|
||||||
|
else:
|
||||||
|
# linear
|
||||||
|
LinearLayerMemoryManager.attach(
|
||||||
|
child_module, module._memory_manager
|
||||||
|
)
|
||||||
|
modules_processed.append(child_module)
|
||||||
|
elif (
|
||||||
|
child_module.__class__.__name__ in CONV_MODULES
|
||||||
|
and child_module not in modules_processed
|
||||||
|
):
|
||||||
|
skip = False
|
||||||
|
if offload_percent < 1.0:
|
||||||
|
# randomly skip some modules
|
||||||
|
if random.random() > offload_percent:
|
||||||
|
skip = True
|
||||||
|
if skip:
|
||||||
|
module._memory_manager.unmanaged_modules.append(child_module)
|
||||||
|
else:
|
||||||
|
# conv
|
||||||
|
ConvLayerMemoryManager.attach(
|
||||||
|
child_module, module._memory_manager
|
||||||
|
)
|
||||||
|
modules_processed.append(child_module)
|
||||||
elif child_module.__class__.__name__ in UNMANAGED_MODULES or any(
|
elif child_module.__class__.__name__ in UNMANAGED_MODULES or any(
|
||||||
inc in child_module.__class__.__name__
|
inc in child_module.__class__.__name__
|
||||||
for inc in UNMANAGED_MODULES_INCLUDES
|
for inc in UNMANAGED_MODULES_INCLUDES
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import {
|
|||||||
import { defaultDatasetConfig } from './jobConfig';
|
import { defaultDatasetConfig } from './jobConfig';
|
||||||
import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
|
import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
|
||||||
import { objectCopy } from '@/utils/basic';
|
import { objectCopy } from '@/utils/basic';
|
||||||
import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
|
import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput, SliderInput } from '@/components/formInputs';
|
||||||
import Card from '@/components/Card';
|
import Card from '@/components/Card';
|
||||||
import { X } from 'lucide-react';
|
import { X } from 'lucide-react';
|
||||||
import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal';
|
import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal';
|
||||||
@@ -214,17 +214,47 @@ export default function SimpleJob({
|
|||||||
/>
|
/>
|
||||||
</FormGroup>
|
</FormGroup>
|
||||||
)}
|
)}
|
||||||
{modelArch?.additionalSections?.includes('model.auto_memory') && (
|
{modelArch?.additionalSections?.includes('model.layer_offloading') && (
|
||||||
<Checkbox
|
<>
|
||||||
label={
|
<Checkbox
|
||||||
<>
|
label={
|
||||||
Auto Memory <IoFlaskSharp className="inline text-yellow-500" name="Experimental" />{' '}
|
<>
|
||||||
</>
|
Layer Offloading <IoFlaskSharp className="inline text-yellow-500" name="Experimental" />{' '}
|
||||||
}
|
</>
|
||||||
checked={jobConfig.config.process[0].model.auto_memory || false}
|
}
|
||||||
onChange={value => setJobConfig(value, 'config.process[0].model.auto_memory')}
|
checked={jobConfig.config.process[0].model.layer_offloading || false}
|
||||||
docKey="model.auto_memory"
|
onChange={value => setJobConfig(value, 'config.process[0].model.layer_offloading')}
|
||||||
/>
|
docKey="model.layer_offloading"
|
||||||
|
/>
|
||||||
|
{jobConfig.config.process[0].model.layer_offloading && (
|
||||||
|
<div className="pt-2">
|
||||||
|
<SliderInput
|
||||||
|
label="Transformer Offload %"
|
||||||
|
value={Math.round(
|
||||||
|
(jobConfig.config.process[0].model.layer_offloading_transformer_percent ?? 1) * 100,
|
||||||
|
)}
|
||||||
|
onChange={value =>
|
||||||
|
setJobConfig(value * 0.01, 'config.process[0].model.layer_offloading_transformer_percent')
|
||||||
|
}
|
||||||
|
min={0}
|
||||||
|
max={100}
|
||||||
|
step={1}
|
||||||
|
/>
|
||||||
|
<SliderInput
|
||||||
|
label="Text Encoder Offload %"
|
||||||
|
value={Math.round(
|
||||||
|
(jobConfig.config.process[0].model.layer_offloading_text_encoder_percent ?? 1) * 100,
|
||||||
|
)}
|
||||||
|
onChange={value =>
|
||||||
|
setJobConfig(value * 0.01, 'config.process[0].model.layer_offloading_text_encoder_percent')
|
||||||
|
}
|
||||||
|
min={0}
|
||||||
|
max={100}
|
||||||
|
step={1}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
)}
|
)}
|
||||||
</Card>
|
</Card>
|
||||||
{disableSections.includes('model.quantize') ? null : (
|
{disableSections.includes('model.quantize') ? null : (
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ export const defaultSliderConfig: SliderConfig = {
|
|||||||
positive_prompt: 'person who is happy',
|
positive_prompt: 'person who is happy',
|
||||||
negative_prompt: 'person who is sad',
|
negative_prompt: 'person who is sad',
|
||||||
target_class: 'person',
|
target_class: 'person',
|
||||||
anchor_class: "",
|
anchor_class: '',
|
||||||
};
|
};
|
||||||
|
|
||||||
export const defaultJobConfig: JobConfig = {
|
export const defaultJobConfig: JobConfig = {
|
||||||
@@ -181,5 +181,11 @@ export const migrateJobConfig = (jobConfig: JobConfig): JobConfig => {
|
|||||||
if (jobConfig?.config?.process && jobConfig.config.process[0]?.type === 'ui_trainer') {
|
if (jobConfig?.config?.process && jobConfig.config.process[0]?.type === 'ui_trainer') {
|
||||||
jobConfig.config.process[0].type = 'diffusion_trainer';
|
jobConfig.config.process[0].type = 'diffusion_trainer';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ('auto_memory' in jobConfig.config.process[0].model) {
|
||||||
|
jobConfig.config.process[0].model.layer_offloading = (jobConfig.config.process[0].model.auto_memory ||
|
||||||
|
false) as boolean;
|
||||||
|
delete jobConfig.config.process[0].model.auto_memory;
|
||||||
|
}
|
||||||
return jobConfig;
|
return jobConfig;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ type AdditionalSections =
|
|||||||
| 'sample.multi_ctrl_imgs'
|
| 'sample.multi_ctrl_imgs'
|
||||||
| 'datasets.num_frames'
|
| 'datasets.num_frames'
|
||||||
| 'model.multistage'
|
| 'model.multistage'
|
||||||
| 'model.auto_memory'
|
| 'model.layer_offloading'
|
||||||
| 'model.low_vram';
|
| 'model.low_vram';
|
||||||
type ModelGroup = 'image' | 'instruction' | 'video';
|
type ModelGroup = 'image' | 'instruction' | 'video';
|
||||||
|
|
||||||
@@ -313,7 +313,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
|
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
|
||||||
},
|
},
|
||||||
disableSections: ['network.conv'],
|
disableSections: ['network.conv'],
|
||||||
additionalSections: ['model.low_vram', 'model.auto_memory'],
|
additionalSections: ['model.low_vram', 'model.layer_offloading'],
|
||||||
accuracyRecoveryAdapters: {
|
accuracyRecoveryAdapters: {
|
||||||
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors',
|
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors',
|
||||||
},
|
},
|
||||||
@@ -334,7 +334,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
|
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
|
||||||
},
|
},
|
||||||
disableSections: ['network.conv'],
|
disableSections: ['network.conv'],
|
||||||
additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram', 'model.auto_memory'],
|
additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram', 'model.layer_offloading'],
|
||||||
accuracyRecoveryAdapters: {
|
accuracyRecoveryAdapters: {
|
||||||
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors',
|
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors',
|
||||||
},
|
},
|
||||||
@@ -360,7 +360,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'datasets.multi_control_paths',
|
'datasets.multi_control_paths',
|
||||||
'sample.multi_ctrl_imgs',
|
'sample.multi_ctrl_imgs',
|
||||||
'model.low_vram',
|
'model.low_vram',
|
||||||
'model.auto_memory',
|
'model.layer_offloading',
|
||||||
],
|
],
|
||||||
accuracyRecoveryAdapters: {
|
accuracyRecoveryAdapters: {
|
||||||
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors',
|
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors',
|
||||||
|
|||||||
@@ -21,17 +21,21 @@ export const handleModelArchChange = (
|
|||||||
setJobConfig(false, 'config.process[0].model.low_vram');
|
setJobConfig(false, 'config.process[0].model.low_vram');
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle auto memory setting
|
// handle layer offloading setting
|
||||||
if (!newArch?.additionalSections?.includes('model.auto_memory')) {
|
if (!newArch?.additionalSections?.includes('model.layer_offloading')) {
|
||||||
if ('auto_memory' in jobConfig.config.process[0].model) {
|
if ('layer_offloading' in jobConfig.config.process[0].model) {
|
||||||
const newModel = objectCopy(jobConfig.config.process[0].model);
|
const newModel = objectCopy(jobConfig.config.process[0].model);
|
||||||
delete newModel.auto_memory;
|
delete newModel.layer_offloading;
|
||||||
|
delete newModel.layer_offloading_text_encoder_percent;
|
||||||
|
delete newModel.layer_offloading_transformer_percent;
|
||||||
setJobConfig(newModel, 'config.process[0].model');
|
setJobConfig(newModel, 'config.process[0].model');
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// set to false if not set
|
// set to false if not set
|
||||||
if (!('auto_memory' in jobConfig.config.process[0].model)) {
|
if (!('layer_offloading' in jobConfig.config.process[0].model)) {
|
||||||
setJobConfig(false, 'config.process[0].model.auto_memory');
|
setJobConfig(false, 'config.process[0].model.layer_offloading');
|
||||||
|
setJobConfig(1.0, 'config.process[0].model.layer_offloading_text_encoder_percent');
|
||||||
|
setJobConfig(1.0, 'config.process[0].model.layer_offloading_transformer_percent');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -296,3 +296,147 @@ export const FormGroup: React.FC<FormGroupProps> = props => {
|
|||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export interface SliderInputProps extends InputProps {
|
||||||
|
value: number;
|
||||||
|
onChange: (value: number) => void;
|
||||||
|
min: number;
|
||||||
|
max: number;
|
||||||
|
step?: number;
|
||||||
|
disabled?: boolean;
|
||||||
|
showValue?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const SliderInput: React.FC<SliderInputProps> = props => {
|
||||||
|
const { label, value, onChange, min, max, step = 1, disabled, className, docKey = null, showValue = true } = props;
|
||||||
|
let { doc } = props;
|
||||||
|
if (!doc && docKey) {
|
||||||
|
doc = getDoc(docKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
const trackRef = React.useRef<HTMLDivElement | null>(null);
|
||||||
|
const [dragging, setDragging] = React.useState(false);
|
||||||
|
|
||||||
|
const clamp = (v: number) => (v < min ? min : v > max ? max : v);
|
||||||
|
const snapToStep = (v: number) => {
|
||||||
|
if (!Number.isFinite(v)) return min;
|
||||||
|
const steps = Math.round((v - min) / step);
|
||||||
|
const snapped = min + steps * step;
|
||||||
|
return clamp(Number(snapped.toFixed(6)));
|
||||||
|
};
|
||||||
|
|
||||||
|
const percent = React.useMemo(() => {
|
||||||
|
if (max === min) return 0;
|
||||||
|
const p = ((value - min) / (max - min)) * 100;
|
||||||
|
return p < 0 ? 0 : p > 100 ? 100 : p;
|
||||||
|
}, [value, min, max]);
|
||||||
|
|
||||||
|
const calcFromClientX = React.useCallback(
|
||||||
|
(clientX: number) => {
|
||||||
|
const el = trackRef.current;
|
||||||
|
if (!el || !Number.isFinite(clientX)) return;
|
||||||
|
const rect = el.getBoundingClientRect();
|
||||||
|
const width = rect.right - rect.left;
|
||||||
|
if (!(width > 0)) return;
|
||||||
|
|
||||||
|
// Clamp ratio to [0, 1] so it can never flip ends.
|
||||||
|
const ratioRaw = (clientX - rect.left) / width;
|
||||||
|
const ratio = ratioRaw <= 0 ? 0 : ratioRaw >= 1 ? 1 : ratioRaw;
|
||||||
|
|
||||||
|
const raw = min + ratio * (max - min);
|
||||||
|
onChange(snapToStep(raw));
|
||||||
|
},
|
||||||
|
[min, max, step, onChange],
|
||||||
|
);
|
||||||
|
|
||||||
|
// Mouse/touch pointer drag
|
||||||
|
const onPointerDown = (e: React.PointerEvent) => {
|
||||||
|
if (disabled) return;
|
||||||
|
e.preventDefault();
|
||||||
|
|
||||||
|
// Capture the pointer so moves outside the element are still tracked correctly
|
||||||
|
try {
|
||||||
|
(e.currentTarget as HTMLElement).setPointerCapture?.(e.pointerId);
|
||||||
|
} catch {}
|
||||||
|
|
||||||
|
setDragging(true);
|
||||||
|
calcFromClientX(e.clientX);
|
||||||
|
|
||||||
|
const handleMove = (ev: PointerEvent) => {
|
||||||
|
ev.preventDefault();
|
||||||
|
calcFromClientX(ev.clientX);
|
||||||
|
};
|
||||||
|
const handleUp = (ev: PointerEvent) => {
|
||||||
|
setDragging(false);
|
||||||
|
// release capture if we got it
|
||||||
|
try {
|
||||||
|
(e.currentTarget as HTMLElement).releasePointerCapture?.(e.pointerId);
|
||||||
|
} catch {}
|
||||||
|
window.removeEventListener('pointermove', handleMove);
|
||||||
|
window.removeEventListener('pointerup', handleUp);
|
||||||
|
};
|
||||||
|
|
||||||
|
window.addEventListener('pointermove', handleMove);
|
||||||
|
window.addEventListener('pointerup', handleUp);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={classNames(className, disabled ? 'opacity-30 cursor-not-allowed' : '')}>
|
||||||
|
{label && (
|
||||||
|
<label className={labelClasses}>
|
||||||
|
{label}{' '}
|
||||||
|
{doc && (
|
||||||
|
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
|
||||||
|
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</label>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className="flex items-center gap-3">
|
||||||
|
<div className="flex-1">
|
||||||
|
<div
|
||||||
|
ref={trackRef}
|
||||||
|
onPointerDown={onPointerDown}
|
||||||
|
className={classNames(
|
||||||
|
'relative w-full h-6 select-none outline-none',
|
||||||
|
disabled ? 'pointer-events-none' : 'cursor-pointer',
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{/* Thicker track */}
|
||||||
|
<div className="pointer-events-none absolute left-0 right-0 top-1/2 -translate-y-1/2 h-3 rounded-sm bg-gray-800 border border-gray-700" />
|
||||||
|
|
||||||
|
{/* Fill */}
|
||||||
|
<div
|
||||||
|
className="pointer-events-none absolute left-0 top-1/2 -translate-y-1/2 h-3 rounded-sm bg-blue-600"
|
||||||
|
style={{ width: `${percent}%` }}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{/* Thumb */}
|
||||||
|
<div
|
||||||
|
onPointerDown={onPointerDown}
|
||||||
|
className={classNames(
|
||||||
|
'absolute top-1/2 -translate-y-1/2 -ml-2',
|
||||||
|
'h-4 w-4 rounded-full bg-white shadow border border-gray-300 cursor-pointer',
|
||||||
|
'after:content-[""] after:absolute after:inset-[-6px] after:rounded-full after:bg-transparent', // expands hit area
|
||||||
|
dragging ? 'ring-2 ring-blue-600' : '',
|
||||||
|
)}
|
||||||
|
style={{ left: `calc(${percent}% )` }}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="flex justify-between text-xs text-gray-500 mt-0.5 select-none">
|
||||||
|
<span>{min}</span>
|
||||||
|
<span>{max}</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{showValue && (
|
||||||
|
<div className="min-w-[3.5rem] text-right text-sm px-3 py-1 bg-gray-800 border border-gray-700 rounded-sm">
|
||||||
|
{Number.isFinite(value) ? value : ''}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|||||||
@@ -185,10 +185,10 @@ const docs: { [key: string]: ConfigDoc } = {
|
|||||||
</>
|
</>
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
'model.auto_memory': {
|
'model.layer_offloading': {
|
||||||
title: (
|
title: (
|
||||||
<>
|
<>
|
||||||
Auto Memory{' '}
|
Layer Offloading{' '}
|
||||||
<span className="text-yellow-500">
|
<span className="text-yellow-500">
|
||||||
( <IoFlaskSharp className="inline text-yellow-500" name="Experimental" /> Experimental)
|
( <IoFlaskSharp className="inline text-yellow-500" name="Experimental" /> Experimental)
|
||||||
</span>
|
</span>
|
||||||
@@ -204,10 +204,14 @@ const docs: { [key: string]: ConfigDoc } = {
|
|||||||
one update to the next. It will also only work with certain models.
|
one update to the next. It will also only work with certain models.
|
||||||
<br />
|
<br />
|
||||||
<br />
|
<br />
|
||||||
Auto Memory uses the CPU RAM instead of the GPU ram to hold most of the model weights. This allows training a
|
Layer Offloading uses the CPU RAM instead of the GPU ram to hold most of the model weights. This allows training a
|
||||||
much larger model on a smaller GPU, assuming you have enough CPU RAM. This is slower than training on pure GPU
|
much larger model on a smaller GPU, assuming you have enough CPU RAM. This is slower than training on pure GPU
|
||||||
RAM, but CPU RAM is cheaper and upgradeable. You will still need GPU RAM to hold the optimizer states and LoRA weights,
|
RAM, but CPU RAM is cheaper and upgradeable. You will still need GPU RAM to hold the optimizer states and LoRA weights,
|
||||||
so a larger card is usually still needed.
|
so a larger card is usually still needed.
|
||||||
|
<br />
|
||||||
|
<br />
|
||||||
|
You can also select the percentage of the layers to offload. It is generally best to offload as few as possible (close to 0%)
|
||||||
|
for best performance, but you can offload more if you need the memory.
|
||||||
</>
|
</>
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -153,7 +153,9 @@ export interface ModelConfig {
|
|||||||
arch: string;
|
arch: string;
|
||||||
low_vram: boolean;
|
low_vram: boolean;
|
||||||
model_kwargs: { [key: string]: any };
|
model_kwargs: { [key: string]: any };
|
||||||
auto_memory?: boolean;
|
layer_offloading?: boolean;
|
||||||
|
layer_offloading_transformer_percent?: number;
|
||||||
|
layer_offloading_text_encoder_percent?: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface SampleItem {
|
export interface SampleItem {
|
||||||
|
|||||||
Reference in New Issue
Block a user