Memory optimizations. Default to cudaMallocAsync for memory allocation when torch >= 2.0

Jaret Burkett
2023-09-12 04:30:23 -06:00
parent e8583860ad
commit d74dd636ee
5 changed files with 104 additions and 5 deletions

@@ -813,6 +813,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if isinstance(batch, DataLoaderBatchDTO):
                 batch.cleanup()
+            # flush every 10 steps
+            if self.step_num % 10 == 0:
+                flush()

         self.progress_bar.close()
         self.sample(self.step_num + 1)
         print("")

run.py

@@ -3,6 +3,8 @@ import sys
 from typing import Union, OrderedDict
 sys.path.insert(0, os.getcwd())
+# must come before ANY torch or fastai imports
+import toolkit.cuda_malloc
 import argparse
 from toolkit.job import get_job

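The new toolkit.cuda_malloc module (below) only sets the PYTORCH_CUDA_ALLOC_CONF environment variable, which PyTorch reads when its CUDA allocator is first initialized; that is why the import has to come before any torch import. A quick illustrative check (hypothetical snippet, not part of the commit):

import os
import toolkit.cuda_malloc  # imported for its side effect on the environment

# On a supported NVIDIA GPU with torch >= 2.0 installed, this should now
# include "backend:cudaMallocAsync". If torch were imported first, setting
# the variable afterwards would have no effect.
print(os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))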
toolkit/cuda_malloc.py (new file)

@@ -0,0 +1,93 @@
# ref comfy ui
import os
import importlib.util


# Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names():
    if os.name == 'nt':
        import ctypes

        # Define necessary C structures and types
        class DISPLAY_DEVICEA(ctypes.Structure):
            _fields_ = [
                ('cb', ctypes.c_ulong),
                ('DeviceName', ctypes.c_char * 32),
                ('DeviceString', ctypes.c_char * 128),
                ('StateFlags', ctypes.c_ulong),
                ('DeviceID', ctypes.c_char * 128),
                ('DeviceKey', ctypes.c_char * 128)
            ]

        # Load user32.dll
        user32 = ctypes.windll.user32

        # Call EnumDisplayDevicesA
        def enum_display_devices():
            device_info = DISPLAY_DEVICEA()
            device_info.cb = ctypes.sizeof(device_info)
            device_index = 0
            gpu_names = set()
            while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0):
                device_index += 1
                gpu_names.add(device_info.DeviceString.decode('utf-8'))
            return gpu_names

        return enum_display_devices()
    else:
        return set()


blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950",
             "GeForce 945M",
             "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745",
             "Quadro K620",
             "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
             "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000",
             "Quadro M5500", "Quadro M6000",
             "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
             "GeForce GTX 1650", "GeForce GTX 1630"
             }


def cuda_malloc_supported():
    try:
        names = get_gpu_names()
    except:
        names = set()
    for x in names:
        if "NVIDIA" in x:
            for b in blacklist:
                if b in x:
                    return False
    return True


cuda_malloc = False
if not cuda_malloc:
    try:
        version = ""
        torch_spec = importlib.util.find_spec("torch")
        for folder in torch_spec.submodule_search_locations:
            ver_file = os.path.join(folder, "version.py")
            if os.path.isfile(ver_file):
                spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
                version = module.__version__
        if int(version[0]) >= 2:  # enable by default for torch version 2.0 and up
            cuda_malloc = cuda_malloc_supported()
    except:
        pass

if cuda_malloc:
    env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
    if env_var is None:
        env_var = "backend:cudaMallocAsync"
    else:
        env_var += ",backend:cudaMallocAsync"
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
    print("CUDA Malloc Async Enabled")

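Once torch is imported after this module, the switch can be verified at runtime. A sketch, assuming torch >= 2.0 with a CUDA device available:

import toolkit.cuda_malloc  # must run before the torch import
import torch

if torch.cuda.is_available():
    # reports "cudaMallocAsync" when the async backend took effect, otherwise "native"
    print(torch.cuda.get_allocator_backend())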
@@ -484,14 +484,14 @@ def get_dataloader_from_datasets(
             drop_last=False,
             shuffle=True,
             collate_fn=dto_collation,  # Use the custom collate function
-            num_workers=1
+            num_workers=0
         )
     else:
         data_loader = DataLoader(
             concatenated_dataset,
             batch_size=batch_size,
             shuffle=True,
-            num_workers=1,
+            num_workers=0,
             collate_fn=dto_collation
         )
     return data_loader

@@ -53,7 +53,7 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag
         self.tensor: Union[torch.Tensor, None] = None

     def cleanup(self):
-        self.tensor = None
+        del self.tensor
         self.cleanup_latent()
@@ -90,7 +90,7 @@ class DataLoaderBatchDTO:
         ) for x in self.file_items]

     def cleanup(self):
-        self.tensor = None
+        del self.latents
+        del self.tensor
         for file_item in self.file_items:
             file_item.cleanup()
-        del self.tensor
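Two things change in these cleanup() methods: the batch now drops its cached latents as well, and attributes are removed with del instead of being set to None, so accidental reuse after cleanup fails loudly rather than silently returning None. A minimal, self-contained illustration of that second point (not part of the diff):

# Illustration only: `del` removes the attribute entirely, `= None` does not.
class Holder:
    def __init__(self):
        self.tensor = object()  # stands in for a cached torch.Tensor

h = Holder()
h.tensor = None   # old behavior: reference released, attribute still exists
del h.tensor      # new behavior: reference released and attribute removed
try:
    h.tensor
except AttributeError:
    print("tensor attribute no longer exists")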