More stream gymnastics

This commit is contained in:
turboderp
2024-09-23 17:28:55 +02:00
parent a5132d072e
commit 15e54046ba
7 changed files with 81 additions and 14 deletions

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import torch
import itertools
from exllamav2.device import get_device_stream
# Emulate pairwise on Python <3.10
@@ -63,6 +64,8 @@ def safe_move_tensor(
# Accept torch.device, string or int
device = torch.device(device)
from_index = tensor.device.index
to_index = device.index
# No move
@@ -71,15 +74,68 @@ def safe_move_tensor(
# Copies to/from system RAM are always fine
if tensor.device.type == "cpu" or device.type == "cpu":
return tensor.to(device, non_blocking = non_blocking)
if tensor.device.type == "cpu":
stream = get_device_stream(to_index)
if stream is not None:
with torch.cuda.stream(stream):
r = tensor.to(device, non_blocking = True)
torch.cuda.synchronize(to_index)
return r
else:
return tensor.to(device, non_blocking = non_blocking)
if device.type == "cpu":
stream = get_device_stream(from_index)
if stream is not None:
with torch.cuda.stream(stream):
r = tensor.to(device, non_blocking = True)
torch.cuda.synchronize(from_index)
return r
else:
return tensor.to(device, non_blocking = non_blocking)
# Source and dest are distinct CUDA devices
# Test tensor.to (once) and if it seems to be working, let Torch decide
if test_gpu_peer_copy(tensor.device, device):
return tensor.to(device, non_blocking = non_blocking)
from_stream = get_device_stream(from_index)
to_stream = get_device_stream(to_index)
if from_stream is not None and to_stream is not None:
with torch.cuda.stream(from_stream):
with torch.cuda.stream(to_stream):
r = tensor.to(device, non_blocking = True)
elif from_stream is not None:
with torch.cuda.stream(from_stream):
r = tensor.to(device, non_blocking = True)
elif to_stream is not None:
with torch.cuda.stream(to_stream):
r = tensor.to(device, non_blocking = True)
else:
r = tensor.to(device, non_blocking = True)
if not non_blocking:
torch.cuda.synchronize(to_index)
return r
# Force move tensor via CPU
return tensor.cpu().to(device)
from_stream = get_device_stream(from_index)
to_stream = get_device_stream(to_index)
if from_stream is not None:
with torch.cuda.stream(from_stream):
tensor_cpu = tensor.to("cpu", non_blocking = True)
torch.cuda.synchronize(from_index)
else:
tensor_cpu = tensor.cpu()
if to_stream is not None:
with torch.cuda.stream(to_stream):
r = tensor_cpu.to(device, non_blocking = True)
torch.cuda.synchronize(to_index)
return r
else:
return tensor_cpu.to(device)

View File

@@ -21,9 +21,15 @@ def set_device_streams():
global global_streams
for(k, v) in global_streams.items():
with torch.cuda.device(torch.device(k)):
torch.cuda.set_device(torch.device(k))
torch.cuda.set_stream(v)
def get_device_stream(index: int):
    """Look up the CUDA stream registered for device *index*.

    Returns the `torch.cuda.Stream` stored in the module-level
    `global_streams` registry, or None when no stream has been
    created for that device yet.
    """
    # Read-only access to a module-level dict needs no `global`
    # declaration; `.get` preserves the None-on-missing behavior.
    return global_streams.get(index)
class ExLlamaV2DeviceContext:
model: ExLlamaV2
@@ -56,7 +62,8 @@ class ExLlamaV2DeviceContext:
# Create streams (only one per device)
if device_idx not in global_streams:
global_streams[device_idx] = torch.cuda.Stream(torch.device(device_idx), -100)
s = torch.cuda.Stream(torch.device(device_idx), -100)
global_streams[device_idx] = s
self.stream = global_streams[device_idx]

View File

@@ -192,7 +192,7 @@ QMatrix::QMatrix
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = 1;
shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
}
QMatrix::~QMatrix()

View File

@@ -80,7 +80,9 @@ uintptr_t make_q_matrix
TORCH_CHECK(temp_dq.size(0) >= dq_req, "Insufficient size of temp_dq buffer")
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
QMatrix* m = new QMatrix
(
stream,
@@ -151,6 +153,7 @@ uintptr_t make_q_matrix_split
TORCH_CHECK(false, "Tensor split not implemented for GPTQ matrices");
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
QMatrix* m = new QMatrix

View File

@@ -35,7 +35,7 @@ void stloader_read
TORCH_CHECK(load_buffer, "Can't allocate buffer for tensor");
cuda_buffer = (uint8_t*) target.data_ptr();
cudaSetDevice(device.value().index());
stream = at::cuda::getCurrentCUDAStream().stream();
stream = at::cuda::getCurrentCUDAStream(device.value().index()).stream();
}
// Synchronization

View File

@@ -6,6 +6,7 @@ import sys
import platform
import threading
from exllamav2.util import get_basic_progress
from exllamav2.compat import safe_move_tensor
extension_name = "exllamav2_ext"
verbose = False # Print wall of text when compiling
@@ -315,8 +316,8 @@ def make_group_map_py(q_groups: torch.Tensor, num_qrows: int) -> torch.Tensor:
return torch.tensor(group_map, dtype = torch.short, device = q_groups.device)
def make_group_map(q_groups: torch.Tensor, num_qrows: int) -> torch.Tensor:
    """Build the group map for a quantized matrix via the C extension.

    The extension kernel runs on a CPU copy of *q_groups*; the result is
    moved back to the source tensor's device with `safe_move_tensor` so
    the copy respects the per-device streams (rather than a plain
    `.to(device)`).

    :param q_groups: group descriptor tensor; only its CPU copy and its
        `.device` are used here
    :param num_qrows: number of quantized rows, forwarded to the extension
    :return: group map tensor on `q_groups.device`
    """
    # NOTE(review): the original text contained both the pre- and
    # post-change bodies of this function back to back (stripped diff
    # markers), leaving the second version unreachable after the first
    # `return`. Keep only the new, stream-aware implementation.
    group_map = ext_c.make_group_map(q_groups.cpu(), num_qrows)
    return safe_move_tensor(group_map, q_groups.device)
# Create Q matrix

View File

@@ -560,16 +560,16 @@ class ExLlamaV2Linear(ExLlamaV2Module):
w = {
"q_scale": safe_move_tensor(self.q_tensors["q_scale"][:, a // 8:b // 8], idx).contiguous(),
"q_scale_max": safe_move_tensor(self.q_tensors["q_scale_max"], idx).contiguous(),
"q_group_map": safe_move_tensor(self.q_tensors["q_group_map"], idx).contiguous(),
"q_groups": safe_move_tensor(self.q_tensors["q_groups"], idx).contiguous(),
"q_scale_max": safe_move_tensor(self.q_tensors["q_scale_max"], idx),
"q_group_map": safe_move_tensor(self.q_tensors["q_group_map"], idx),
"q_groups": safe_move_tensor(self.q_tensors["q_groups"], idx),
"q_weight": safe_move_tensor(self.q_tensors["q_weight"][:, a:b], idx).contiguous()
}
if "q_perm" in self.q_tensors:
w.update({
"q_perm": safe_move_tensor(self.q_tensors["q_perm"], idx).contiguous(),
"q_invperm": safe_move_tensor(self.q_tensors["q_invperm"], idx).contiguous(),
"q_perm": safe_move_tensor(self.q_tensors["q_perm"], idx),
"q_invperm": safe_move_tensor(self.q_tensors["q_invperm"], idx),
})
if "bias" in self.q_tensors: