Row split + all_reduce for MLP (not faster, disabled)

This commit is contained in:
turboderp
2024-08-20 00:44:19 +02:00
parent 373bcc187e
commit f17feb8345
4 changed files with 344 additions and 26 deletions

View File

@@ -134,4 +134,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("tp_broadcast", &tp_broadcast, "tp_broadcast");
m.def("tp_gather", &tp_gather, "tp_gather");
m.def("tp_cross_device_barrier", &tp_cross_device_barrier, "tp_cross_device_barrier");
m.def("tp_all_reduce", &tp_all_reduce, "tp_all_reduce");
}

View File

@@ -67,6 +67,13 @@ ExtTPContext::ExtTPContext
cudaHostAlloc((void**)&tp_data, sizeof(ExtTPData), cudaHostAllocMapped);
init_tp_data(tp_data);
// comms.resize(all_devices.size());
// ncclCommInitAll(&comms[0], all_devices.size(), &all_devices[0]);
// comms_index.resize(streams.size());
// for (int i = 0; i < all_devices.size(); ++i)
// comms_index[all_devices[i]] = i;
}
ExtTPContext::~ExtTPContext()
@@ -75,6 +82,9 @@ ExtTPContext::~ExtTPContext()
delete thread_pool;
#endif
// for (int i = 0; i < comms.size(); ++i)
// ncclCommDestroy(comms[i]);
cudaFreeHost(tp_data);
}
@@ -353,4 +363,133 @@ void tp_cross_device_barrier
cuda_check(cudaStreamWaitEvent(ctx->streams[dev_i], ctx->sync_events[dev_j], 0));
}
}
}
}
//void tp_all_reduce_nccl
//(
// uintptr_t tp_context,
// const std::vector<torch::Tensor> &tensors
//)
//{
// ExtTPContext* ctx = reinterpret_cast<ExtTPContext*> (tp_context);
//
// ncclGroupStart();
//
// for (int i = 0; i < tensors.size(); ++i)
// {
// int dev = tensors[i].device().index();
// int comms_i = ctx->comms_index[dev];
//
// ncclAllReduce
// (
// tensors[i].data_ptr(),
// tensors[i].data_ptr(),
// tensors[i].numel(),
// ncclFloat16,
// ncclSum,
// ctx->comms[comms_i],
// ctx->streams[dev]
// );
// }
//
// ncclGroupEnd();
//}
//void tp_all_reduce
//(
// uintptr_t tp_context,
// const std::vector<torch::Tensor> &tensors
//)
// Sum-reduce `tensors` across devices by chaining through a host staging buffer, then copy
// the result back into every residual.
//
// tp_context: address of an ExtTPContext, passed as an integer
// buffer:     index into ctx->pinned_temp selecting the host staging buffer
// tensors:    one tensor per participating device; the byte size of tensors[0] is used for
//             every copy, so all tensors and residuals are expected to have that size
// residuals:  one tensor per entry in `tensors`, on the same device as the matching tensor.
//             residuals.size() must be at least tensors.size() (not checked here)
//
// Stage 1 walks the devices in list order: each device (after the first) waits for the
// previous device's event, loads the running sum from the host buffer into its residual,
// adds its own tensor and writes the new running sum back to the host buffer. The running
// sum starts from residuals[0], so the prior contents of residuals[1..] are overwritten.
// Stage 2 copies the final sum from the host buffer into the residual of every device except
// the last, which already holds it.
//
// All copies and event operations are enqueued on ctx->streams[dev]; this function does not
// itself wait for them to complete. As a side effect the current torch CUDA stream of each
// visited device is left set to ctx->streams[dev].
void tp_all_reduce
(
    uintptr_t tp_context,
    int buffer,
    const std::vector<torch::Tensor> &tensors,
    const std::vector<torch::Tensor> &residuals
)
{
    // Nothing to reduce. Also keeps tensors[0] and tensors[num - 1] below in bounds
    if (tensors.empty()) return;

    ExtTPContext* ctx = reinterpret_cast<ExtTPContext*> (tp_context);
    const size_t size = tensors[0].numel() * tensors[0].element_size();
    const size_t num = tensors.size();

    // Reduction via host buffer
    for (size_t i = 0; i < num; ++i)
    {
        int dev = tensors[i].device().index();
        auto torch_stream = at::cuda::getStreamFromExternal(ctx->streams[dev], dev);
        cuda_check(cudaSetDevice(dev));

        // Make the in-place add below run on the same stream as the surrounding copies
        at::cuda::setCurrentCUDAStream(torch_stream);

        if (i > 0)
        {
            int prev_dev = tensors[i - 1].device().index();

            // Copy host buffer to current residual, once the previous device has written it
            cuda_check(cudaStreamWaitEvent
            (
                ctx->streams[dev],
                ctx->sync_events[prev_dev],
                0
            ));
            cuda_check(cudaMemcpyAsync
            (
                residuals[i].data_ptr(),
                ctx->pinned_temp[buffer],
                size,
                cudaMemcpyHostToDevice,
                ctx->streams[dev]
            ));
        }

        // Add current tensor to current residual
        residuals[i].add_(tensors[i]);

        // Copy current residual to host buffer
        cuda_check(cudaMemcpyAsync
        (
            ctx->pinned_temp[buffer],
            residuals[i].data_ptr(),
            size,
            cudaMemcpyDeviceToHost,
            ctx->streams[dev]
        ));

        // Mark the point at which this device's running sum is in the host buffer
        cuda_check(cudaEventRecord
        (
            ctx->sync_events[dev],
            ctx->streams[dev]
        ));
    }

    // Broadcast result. The last device already holds the full sum in its residual
    int last_dev = tensors[num - 1].device().index();
    for (size_t i = 0; i + 1 < num; ++i)
    {
        int dev = tensors[i].device().index();
        cuda_check(cudaSetDevice(dev));
        cuda_check(cudaStreamWaitEvent
        (
            ctx->streams[dev],
            ctx->sync_events[last_dev],
            0
        ));
        cuda_check(cudaMemcpyAsync
        (
            residuals[i].data_ptr(),
            ctx->pinned_temp[buffer],
            size,
            cudaMemcpyHostToDevice,
            ctx->streams[dev]
        ));
    }
}

View File

@@ -9,6 +9,7 @@
//#define TP_MULTITHREADED
//#include <nccl.h>
#include "cpp/threadpool.h"
#include "cuda/tp.cuh"
@@ -32,6 +33,8 @@ public:
void* mapped_globals;
std::vector<cudaEvent_t> sync_events;
// std::vector<ncclComm_t> comms;
// std::vector<int> comms_index;
ExtTPContext
(
@@ -104,4 +107,18 @@ void tp_cross_device_barrier
int next_stage = -1
);
//void tp_all_reduce
//(
// uintptr_t tp_context,
// const std::vector<torch::Tensor> &tensors
//);
// Sum `tensors` (one per device) into `residuals` (one per device) by passing a running sum
// from device to device through the host buffer ctx->pinned_temp[buffer], then copying the
// final sum back to every residual. `tp_context` is the address of an ExtTPContext. The work
// is enqueued on the context's per-device streams.
void tp_all_reduce
(
uintptr_t tp_context,
int buffer,
const std::vector<torch::Tensor> &tensors,
const std::vector<torch::Tensor> &residuals
);
#endif

View File

@@ -21,6 +21,7 @@ class ExLlamaV2Linear(ExLlamaV2Module):
name: str = "Linear"
in_features: int
in_features_tp: list[int] | None
out_features: int
out_features_tp: list[int] | None
has_bias: bool
@@ -46,6 +47,7 @@ class ExLlamaV2Linear(ExLlamaV2Module):
is_tp: bool
broadcast_type: int | None
broadcast_type_out: int | None
is_sub_module: bool
def __init__(self,
@@ -69,6 +71,7 @@ class ExLlamaV2Linear(ExLlamaV2Module):
self.is_tp = False
self.broadcast_type = None
self.broadcast_type_out = None
if pad32:
self.padding = -out_features % 32
@@ -260,9 +263,10 @@ class ExLlamaV2Linear(ExLlamaV2Module):
(self.temp_fwd_size() if self.is_sub_module else 0)
def temp_dq_size(self, out_features = None) -> int:
def temp_dq_size(self, in_features = None, out_features = None) -> int:
dq = self.in_features * (self.out_features if out_features is None else out_features)
dq = (self.in_features if in_features is None else in_features) * \
(self.out_features if out_features is None else out_features)
dq = 2 * min(dq, self.model.config.max_dq_size)
return dq
@@ -301,17 +305,32 @@ class ExLlamaV2Linear(ExLlamaV2Module):
) -> torch.Tensor | dict[str: torch.Tensor]:
if self.is_tp:
return self.forward_tp(
hidden_states,
cache,
attn_params,
past_len,
intermediates,
loras,
force_recons,
force_cuda,
**kwargs
)
if self.out_features_tp:
return self.forward_tp(
hidden_states,
cache,
attn_params,
past_len,
intermediates,
loras,
force_recons,
force_cuda,
**kwargs
)
elif self.in_features_tp:
return self.forward_tp_row(
hidden_states,
cache,
attn_params,
past_len,
intermediates,
loras,
force_recons,
force_cuda,
**kwargs
)
else:
assert False, "Unitialized TP linear layer"
# Linear forward
@@ -375,12 +394,13 @@ class ExLlamaV2Linear(ExLlamaV2Module):
**kwargs
) -> torch.Tensor | dict[str: torch.Tensor]:
split = self.model.tp_context.get_split(self.broadcast_type)
ctx = self.model.tp_context
split = ctx.get_split(self.broadcast_type)
if isinstance(hidden_states, torch.Tensor):
output_shape = hidden_states.shape[:-1] + (self.out_features,)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = self.model.tp_context.broadcast(hidden_states, self.broadcast_type)
hidden_states = ctx.broadcast(hidden_states, self.broadcast_type)
else:
output_shape = hidden_states[0].shape[:-1] + (self.out_features,)
hidden_states = [hs.view(-1, hs.shape[-1]) for hs in hidden_states]
@@ -398,18 +418,55 @@ class ExLlamaV2Linear(ExLlamaV2Module):
self.q_handle,
outputs,
force_cuda,
self.model.tp_context.ext_tp_context,
ctx.ext_tp_context,
-1
)
if output_split:
return outputs
output = self.model.tp_context.gather(0, outputs, self.broadcast_type)
output = ctx.gather(0, outputs, self.broadcast_type)
hidden_states_out = output.view(output_shape)
return hidden_states_out
def forward_tp_row(
    self,
    hidden_states: list[torch.Tensor],
    cache = None,
    attn_params = None,
    past_len = None,
    intermediates: bool = False,
    loras: list[ExLlamaV2Lora] | None = None,
    force_recons: bool = False,
    force_cuda: bool = False,
    output_split: bool = False,
    dim: int = 1,
    **kwargs
) -> torch.Tensor | dict[str: torch.Tensor]:
    """
    Forward pass for a layer that was split with tp_split_row.

    :param hidden_states:
        List of input tensors, passed as-is to the TP GEMM. The first entry supplies the row
        count and dtype used to request the output tensors.

    :param force_cuda:
        Forwarded unchanged to ext_c.gemm_half_q_half_tp.

    :param output_split:
        Must be True. This method only returns split outputs and asserts otherwise.

    :return:
        The tensors obtained from tp_context.get_temp_tensors_bc for broadcast_type_out, after
        they have been handed to the GEMM as its outputs. No reduction or gather is performed
        here.

    The remaining parameters are accepted for signature compatibility and are not read.
    """

    tp = self.model.tp_context

    # Same lookup as before the restyle; its result is not used in this method
    _split = tp.get_split(self.broadcast_type)

    assert isinstance(hidden_states, list)
    assert output_split

    first = hidden_states[0]
    partials = tp.get_temp_tensors_bc(first.shape[0], first.dtype, self.broadcast_type_out)

    ext_c.gemm_half_q_half_tp(hidden_states, self.q_handle, partials, force_cuda, tp.ext_tp_context, -1)
    return partials
def get_weight_tensor_dq(self) -> torch.Tensor:
if self.linear is not None:
@@ -470,8 +527,6 @@ class ExLlamaV2Linear(ExLlamaV2Module):
assert all(x in self.q_tensors for x in [
"q_scale",
"q_scale_max",
"q_perm",
"q_invperm",
"q_group_map",
"q_groups",
"q_weight"
@@ -481,7 +536,6 @@ class ExLlamaV2Linear(ExLlamaV2Module):
ctx = self.model.tp_context
self.broadcast_type = broadcast_type
split = ctx.get_split(broadcast_type)
maxdev = max(dev for dev, _, _ in split)
if dim:
split = [(d, a * dim, b * dim) for (d, a, b) in split]
@@ -501,13 +555,17 @@ class ExLlamaV2Linear(ExLlamaV2Module):
w = {
"q_scale": safe_move_tensor(self.q_tensors["q_scale"][:, a // 8:b // 8], idx).contiguous(),
"q_scale_max": safe_move_tensor(self.q_tensors["q_scale_max"], idx).contiguous(),
"q_perm": safe_move_tensor(self.q_tensors["q_perm"], idx).contiguous(),
"q_invperm": safe_move_tensor(self.q_tensors["q_invperm"], idx).contiguous(),
"q_group_map": safe_move_tensor(self.q_tensors["q_group_map"], idx).contiguous(),
"q_groups": safe_move_tensor(self.q_tensors["q_groups"], idx).contiguous(),
"q_weight": safe_move_tensor(self.q_tensors["q_weight"][:, a:b], idx).contiguous()
}
if "q_perm" in self.q_tensors:
w.update({
"q_perm": safe_move_tensor(self.q_tensors["q_perm"], idx).contiguous(),
"q_invperm": safe_move_tensor(self.q_tensors["q_invperm"], idx).contiguous(),
})
if "bias" in self.q_tensors:
w["bias"] = safe_move_tensor(self.q_tensors["bias"][a:b], idx).contiguous()
@@ -515,14 +573,14 @@ class ExLlamaV2Linear(ExLlamaV2Module):
device_context = self.model.get_device_context(idx)
device_context.begin_scratch_alloc()
new_temp_dq.append(device_context.get_scratch_slice(self.temp_dq_size(s)))
new_temp_dq.append(device_context.get_scratch_slice(self.temp_dq_size(out_features = s)))
max_dq_rows = cfg.max_dq_size // s
new_q_handle.append(
ext_c.make_q_matrix_split(
w["q_weight"],
w["q_perm"],
w["q_invperm"],
w.get("q_perm", none_tensor),
w.get("q_invperm", none_tensor),
w["q_scale"],
w["q_scale_max"],
w["q_groups"],
@@ -542,3 +600,106 @@ class ExLlamaV2Linear(ExLlamaV2Module):
self.temp_dq = new_temp_dq
self.is_tp = True
def tp_split_row(self, broadcast_type_in: int, broadcast_type_out: int, dim = None):
    """
    Split this quantized layer along its input (row) dimension, one shard per entry of the
    TP split for broadcast_type_in.

    :param broadcast_type_in:
        Split type used to look up the (device, start, end) ranges over the input features.
        Stored as self.broadcast_type.

    :param broadcast_type_out:
        Stored as self.broadcast_type_out for use by the forward pass.

    :param dim:
        Optional multiplier applied to the start and end of every split range.

    On success self.q_handle, self.q_tensors and self.temp_dq are replaced by lists with one
    entry per non-empty shard, self.in_features_tp holds the shard width per device, the
    original q_handle is freed and self.is_tp is set. Requires a loaded EXL2 tensor without
    q_perm/q_invperm.
    """

    assert self.q_handle is not None, \
        "Can only split quantized tensor."
    assert all(x in self.q_tensors for x in [
        "q_scale",
        "q_scale_max",
        "q_group_map",
        "q_groups",
        "q_weight"
    ]), "Can only split fully loaded EXL2 tensor."
    assert not any(x in self.q_tensors for x in [
        "q_perm",
        "q_invperm",
    ]), "Tensor not prepared for row split"

    cfg = self.model.config
    ctx = self.model.tp_context
    self.broadcast_type = broadcast_type_in
    self.broadcast_type_out = broadcast_type_out
    split = ctx.get_split(broadcast_type_in)

    if dim:
        split = [(d, a * dim, b * dim) for (d, a, b) in split]

    self.in_features_tp = [0] * ctx.num_devices
    for dev, a, b in split:
        self.in_features_tp[dev] = b - a

    new_q_handle = []
    new_q_tensors = []
    new_temp_dq = []

    q_scale = self.q_tensors["q_scale"]
    q_scale_max = self.q_tensors["q_scale_max"]
    q_group_map = self.q_tensors["q_group_map"]
    q_groups = self.q_tensors["q_groups"]
    q_weight = self.q_tensors["q_weight"]

    # Check the second map entry of the last row covered by the split. Uses the end of the
    # final split range explicitly rather than a variable leaked from the loop above.
    # NOTE(review): presumably that entry is the number of rows left in the row's group, so
    # this checks that the split ends on a group boundary - confirm
    last_end = split[-1][2]
    assert q_group_map[last_end * 2 - 1].item() == 1

    for idx, a, b in split:
        s = b - a
        if s == 0: continue

        # Even entries of q_group_map give the group index of each input row
        groups = q_groups.shape[0] // 2
        group_a = q_group_map[a * 2].item()
        group_b = q_group_map[(b - 1) * 2].item() + 1

        # Odd entries of q_groups give the first q_weight row of each group
        row_a = q_groups[group_a * 2 + 1].item()
        row_b = q_weight.shape[0] if group_b == groups else q_groups[group_b * 2 + 1].item()

        tq_scale = q_scale[group_a:group_b, :]
        tq_scale_max = q_scale_max[group_a:group_b]
        tq_weight = q_weight[row_a:row_b, :]

        # Slices are views of the source tensors. Copy before rebasing so the in-place
        # subtraction cannot alter self.q_tensors, which later shards still read from and
        # which must stay intact if anything below raises.
        tq_group_map = q_group_map[a * 2:b * 2].clone()
        tq_groups = q_groups[group_a * 2:group_b * 2].clone()
        tq_group_map[::2] -= group_a
        tq_groups[1::2] -= row_a

        w = {
            "q_scale": safe_move_tensor(tq_scale, idx).clone().contiguous(),
            "q_scale_max": safe_move_tensor(tq_scale_max, idx).clone().contiguous(),
            "q_group_map": safe_move_tensor(tq_group_map, idx).contiguous(),
            "q_groups": safe_move_tensor(tq_groups, idx).contiguous(),
            "q_weight": safe_move_tensor(tq_weight, idx).clone().contiguous()
        }

        # NOTE(review): every shard receives the full bias. If the matmul applies it per
        # shard and the shard outputs are then summed, the bias would be counted once per
        # shard - confirm
        if "bias" in self.q_tensors:
            w["bias"] = safe_move_tensor(self.q_tensors["bias"], idx).contiguous()

        new_q_tensors.append(w)

        device_context = self.model.get_device_context(idx)
        device_context.begin_scratch_alloc()
        new_temp_dq.append(device_context.get_scratch_slice(self.temp_dq_size(in_features = s)))

        # NOTE(review): not divided by a feature count here - confirm this is the intended
        # row limit for a row split
        max_dq_rows = cfg.max_dq_size

        new_q_handle.append(
            ext_c.make_q_matrix_split(
                w["q_weight"],
                none_tensor,
                none_tensor,
                w["q_scale"],
                w["q_scale_max"],
                w["q_groups"],
                w["q_group_map"],
                none_tensor,
                none_tensor,
                none_tensor,
                w.get("bias", none_tensor),
                new_temp_dq[-1],
                max_dq_rows
            )
        )

    ext_c.free_q_matrix(self.q_handle)
    self.q_handle = new_q_handle
    self.q_tensors = new_q_tensors
    self.temp_dq = new_temp_dq
    self.is_tp = True