Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Memory optimizations. Default to using cudaMallocAsync for memory allocation on torch 2.0+
@@ -813,6 +813,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
+            if isinstance(batch, DataLoaderBatchDTO):
+                batch.cleanup()
+
+            # flush every 10 steps
             if self.step_num % 10 == 0:
                 flush()
 
         self.progress_bar.close()
         self.sample(self.step_num + 1)
         print("")
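flush() is imported from elsewhere in the toolkit and is not shown in this diff; a minimal sketch of what such a helper typically does, assuming the usual gc-plus-empty_cache pattern:

    import gc
    import torch

    def flush():
        # drop unreachable Python objects first, then return cached CUDA
        # blocks to the driver so later allocations can reuse the memory
        gc.collect()
        torch.cuda.empty_cache()

Calling it only every 10 steps keeps the (relatively expensive) cache flush off the hot path while still bounding fragmentation.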
2  run.py
@@ -3,6 +3,8 @@ import sys
 from typing import Union, OrderedDict
 
 sys.path.insert(0, os.getcwd())
+# must come before ANY torch or fastai imports
+import toolkit.cuda_malloc
 import argparse
 from toolkit.job import get_job
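toolkit.cuda_malloc (added below) sets PYTORCH_CUDA_ALLOC_CONF, which PyTorch reads when it initializes CUDA, so the import has to run before anything pulls in torch. A quick sanity check that the backend took effect (assumes a CUDA build of torch >= 2.0, which exposes the query):

    import toolkit.cuda_malloc  # must set the env var before torch loads
    import torch

    # "cudaMallocAsync" when the async allocator is active, "native" otherwise
    print(torch.cuda.get_allocator_backend())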
93  toolkit/cuda_malloc.py  Normal file
@@ -0,0 +1,93 @@
+# ref comfy ui
+import os
+import importlib.util
+
+
+# Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
+def get_gpu_names():
+    if os.name == 'nt':
+        import ctypes
+
+        # Define necessary C structures and types
+        class DISPLAY_DEVICEA(ctypes.Structure):
+            _fields_ = [
+                ('cb', ctypes.c_ulong),
+                ('DeviceName', ctypes.c_char * 32),
+                ('DeviceString', ctypes.c_char * 128),
+                ('StateFlags', ctypes.c_ulong),
+                ('DeviceID', ctypes.c_char * 128),
+                ('DeviceKey', ctypes.c_char * 128)
+            ]
+
+        # Load user32.dll
+        user32 = ctypes.windll.user32
+
+        # Call EnumDisplayDevicesA
+        def enum_display_devices():
+            device_info = DISPLAY_DEVICEA()
+            device_info.cb = ctypes.sizeof(device_info)
+            device_index = 0
+            gpu_names = set()
+
+            while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0):
+                device_index += 1
+                gpu_names.add(device_info.DeviceString.decode('utf-8'))
+            return gpu_names
+
+        return enum_display_devices()
+    else:
+        return set()
+
+
+blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950",
+             "GeForce 945M",
+             "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745",
+             "Quadro K620",
+             "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
+             "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000",
+             "Quadro M5500", "Quadro M6000",
+             "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
+             "GeForce GTX 1650", "GeForce GTX 1630"
+             }
+
+
+def cuda_malloc_supported():
+    try:
+        names = get_gpu_names()
+    except:
+        names = set()
+    for x in names:
+        if "NVIDIA" in x:
+            for b in blacklist:
+                if b in x:
+                    return False
+    return True
+
+
+cuda_malloc = False
+
+if not cuda_malloc:
+    try:
+        version = ""
+        torch_spec = importlib.util.find_spec("torch")
+        for folder in torch_spec.submodule_search_locations:
+            ver_file = os.path.join(folder, "version.py")
+            if os.path.isfile(ver_file):
+                spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
+                module = importlib.util.module_from_spec(spec)
+                spec.loader.exec_module(module)
+                version = module.__version__
+        if int(version[0]) >= 2:  # enable by default for torch version 2.0 and up
+            cuda_malloc = cuda_malloc_supported()
+    except:
+        pass
+
+if cuda_malloc:
+    env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
+    if env_var is None:
+        env_var = "backend:cudaMallocAsync"
+    else:
+        env_var += ",backend:cudaMallocAsync"
+
+    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
+    print("CUDA Malloc Async Enabled")
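Note that the module appends to any PYTORCH_CUDA_ALLOC_CONF the user already set rather than overwriting it. A minimal sketch of that composition, with a hypothetical pre-existing value:

    import os

    # hypothetical: user already set an allocator option
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "garbage_collection_threshold:0.8"

    env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
    if env_var is None:
        env_var = "backend:cudaMallocAsync"
    else:
        env_var += ",backend:cudaMallocAsync"
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var

    print(os.environ['PYTORCH_CUDA_ALLOC_CONF'])
    # garbage_collection_threshold:0.8,backend:cudaMallocAsync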
@@ -484,14 +484,14 @@ def get_dataloader_from_datasets(
             drop_last=False,
             shuffle=True,
             collate_fn=dto_collation,  # Use the custom collate function
-            num_workers=1
+            num_workers=0
         )
     else:
         data_loader = DataLoader(
             concatenated_dataset,
             batch_size=batch_size,
             shuffle=True,
-            num_workers=1,
+            num_workers=0,
             collate_fn=dto_collation
         )
     return data_loader
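Dropping num_workers from 1 to 0 moves batch loading into the main process: no worker subprocess, so no second copy of the dataset state, at the cost of overlapping loading with training. A standalone illustration with a toy dataset (not the repo's classes):

    import torch
    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return torch.randn(3, 64, 64)

    # num_workers=0: batches are produced in the main process,
    # trading loading parallelism for a smaller memory footprint
    loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True, num_workers=0)
    for batch in loader:
        print(batch.shape)  # torch.Size([4, 3, 64, 64])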
@@ -53,7 +53,7 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag
         self.tensor: Union[torch.Tensor, None] = None
 
     def cleanup(self):
-        self.tensor = None
+        del self.tensor
         self.cleanup_latent()
@@ -90,7 +90,7 @@ class DataLoaderBatchDTO:
         ) for x in self.file_items]
 
     def cleanup(self):
-        self.tensor = None
+        del self.latents
+        del self.tensor
         for file_item in self.file_items:
             file_item.cleanup()
-        del self.tensor
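Both cleanup styles release the object's reference to the tensor; del additionally removes the attribute itself, so a stale handle fails loudly instead of quietly returning None. A small illustration of the difference (not from the repo):

    import torch

    class Holder:
        def __init__(self):
            self.tensor = torch.randn(4)

    a = Holder()
    a.tensor = None          # attribute survives, reference is dropped
    print(a.tensor is None)  # True

    b = Holder()
    del b.tensor             # attribute is gone entirely
    print(hasattr(b, 'tensor'))  # False

Once no references remain, a CUDA tensor's memory goes back to PyTorch's caching allocator; the periodic flush() in the train loop then hands cached blocks back to the driver.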