mirror of https://github.com/turboderp-org/exllamav2.git
Row split + all_reduce for MLP (not faster, disabled)
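The change trialed here: split the MLP down projection along its input (row) dimension instead of its output (column) dimension, so each GPU computes a full-width partial product from its slice of the hidden state, and the partials are then summed across devices with an all-reduce. A minimal sketch of the arithmetic this relies on (illustrative shapes, not code from the commit):

import torch

# x @ W == x_a @ W_a + x_b @ W_b when W is split along its input (row)
# dimension and x is split to match along its last dimension.
x = torch.randn(4, 8)
W = torch.randn(8, 16)

W_a, W_b = W[:5, :], W[5:, :]    # row split of the weight
x_a, x_b = x[:, :5], x[:, 5:]    # matching split of the activations

partial_a = x_a @ W_a            # would run on device 0
partial_b = x_b @ W_b            # would run on device 1
reduced = partial_a + partial_b  # the all_reduce step

assert torch.allclose(reduced, x @ W, atol = 1e-5)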
@@ -134,4 +134,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("tp_broadcast", &tp_broadcast, "tp_broadcast");
     m.def("tp_gather", &tp_gather, "tp_gather");
     m.def("tp_cross_device_barrier", &tp_cross_device_barrier, "tp_cross_device_barrier");
+    m.def("tp_all_reduce", &tp_all_reduce, "tp_all_reduce");
 }
@@ -67,6 +67,13 @@ ExtTPContext::ExtTPContext
 
     cudaHostAlloc((void**)&tp_data, sizeof(ExtTPData), cudaHostAllocMapped);
     init_tp_data(tp_data);
 
+//    comms.resize(all_devices.size());
+//    ncclCommInitAll(&comms[0], all_devices.size(), &all_devices[0]);
+//    comms_index.resize(streams.size());
+//    for (int i = 0; i < all_devices.size(); ++i)
+//        comms_index[all_devices[i]] = i;
+
 }
 
 ExtTPContext::~ExtTPContext()
@@ -75,6 +82,9 @@ ExtTPContext::~ExtTPContext()
     delete thread_pool;
     #endif
 
+//    for (int i = 0; i < comms.size(); ++i)
+//        ncclCommDestroy(comms[i]);
+
     cudaFreeHost(tp_data);
 }
@@ -353,4 +363,133 @@ void tp_cross_device_barrier
                 cuda_check(cudaStreamWaitEvent(ctx->streams[dev_i], ctx->sync_events[dev_j], 0));
             }
         }
     }
 }
+
+//void tp_all_reduce_nccl
+//(
+//    uintptr_t tp_context,
+//    const std::vector<torch::Tensor> &tensors
+//)
+//{
+//    ExtTPContext* ctx = reinterpret_cast<ExtTPContext*> (tp_context);
+//
+//    ncclGroupStart();
+//
+//    for (int i = 0; i < tensors.size(); ++i)
+//    {
+//        int dev = tensors[i].device().index();
+//        int comms_i = ctx->comms_index[dev];
+//
+//        ncclAllReduce
+//        (
+//            tensors[i].data_ptr(),
+//            tensors[i].data_ptr(),
+//            tensors[i].numel(),
+//            ncclFloat16,
+//            ncclSum,
+//            ctx->comms[comms_i],
+//            ctx->streams[dev]
+//        );
+//    }
+//
+//    ncclGroupEnd();
+//}
+
+//void tp_all_reduce
+//(
+//    uintptr_t tp_context,
+//    const std::vector<torch::Tensor> &tensors
+//)
+
+void tp_all_reduce
+(
+    uintptr_t tp_context,
+    int buffer,
+    const std::vector<torch::Tensor> &tensors,
+    const std::vector<torch::Tensor> &residuals
+)
+{
+    ExtTPContext* ctx = reinterpret_cast<ExtTPContext*> (tp_context);
+
+    size_t size = tensors[0].numel() * tensors[0].element_size();
+    size_t num = tensors.size();
+
+    // Reduction via host buffer
+
+    for (int i = 0; i < num; ++i)
+    {
+        int dev = tensors[i].device().index();
+        auto torch_stream = at::cuda::getStreamFromExternal(ctx->streams[dev], dev);
+        cudaSetDevice(dev);
+        at::cuda::setCurrentCUDAStream(torch_stream);
+
+        if (i > 0)
+        {
+            int prev_dev = tensors[i - 1].device().index();
+
+            // Copy host buffer to current residual
+
+            cuda_check(cudaStreamWaitEvent
+            (
+                ctx->streams[dev],
+                ctx->sync_events[prev_dev],
+                0
+            ));
+
+            cuda_check(cudaMemcpyAsync
+            (
+                residuals[i].data_ptr(),
+                ctx->pinned_temp[buffer],
+                size,
+                cudaMemcpyHostToDevice,
+                ctx->streams[dev]
+            ));
+        }
+
+        // Add current tensor to current residual
+
+        residuals[i].add_(tensors[i]);
+
+        // Copy current residual to host buffer
+
+        cuda_check(cudaMemcpyAsync
+        (
+            ctx->pinned_temp[buffer],
+            residuals[i].data_ptr(),
+            size,
+            cudaMemcpyDeviceToHost,
+            ctx->streams[dev]
+        ));
+
+        cuda_check(cudaEventRecord
+        (
+            ctx->sync_events[dev],
+            ctx->streams[dev]
+        ));
+    }
+
+    // Broadcast result
+
+    int last_dev = tensors[num - 1].device().index();
+
+    for (int i = 0; i < num - 1; ++i)
+    {
+        int dev = tensors[i].device().index();
+        cudaSetDevice(dev);
+
+        cuda_check(cudaStreamWaitEvent
+        (
+            ctx->streams[dev],
+            ctx->sync_events[last_dev],
+            0
+        ));
+        cuda_check(cudaMemcpyAsync
+        (
+            residuals[i].data_ptr(),
+            ctx->pinned_temp[buffer],
+            size,
+            cudaMemcpyHostToDevice,
+            ctx->streams[dev]
+        ));
+    }
+}
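The tp_all_reduce above is a serialized all-reduce staged through a single pinned host buffer: device i waits on device i-1's event, copies the running sum from host memory into its residual, adds its own partial, and writes the new sum back to the host buffer; once the last device has finished, every other device copies the final sum down. A rough Python model of that data flow (buffers as arrays; names like host are illustrative stand-ins for ctx->pinned_temp[buffer]):

import numpy as np

def staged_all_reduce(partials, residuals):
    # Sequential reduction pass: each "device" folds its partial into the
    # running sum carried through the host staging buffer.
    host = None
    for i in range(len(partials)):
        if i > 0:
            residuals[i][:] = host      # H2D copy of the running sum
        residuals[i] += partials[i]     # residuals[i].add_(tensors[i])
        host = residuals[i].copy()      # D2H copy of the new running sum
    # Broadcast pass: all but the last device fetch the final sum.
    for i in range(len(partials) - 1):
        residuals[i][:] = host

parts = [np.ones(4) * (i + 1) for i in range(3)]
res = [np.zeros(4) for _ in range(3)]
staged_all_reduce(parts, res)
assert all(np.allclose(r, 6.0) for r in res)

Because every step round-trips through host memory in series, the critical path grows linearly with the device count, which is plausibly why the commit message reports the path as not faster and leaves it disabled.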
@@ -9,6 +9,7 @@
 
 //#define TP_MULTITHREADED
 
+//#include <nccl.h>
 #include "cpp/threadpool.h"
 #include "cuda/tp.cuh"
 
@@ -32,6 +33,8 @@ public:
     void* mapped_globals;
 
     std::vector<cudaEvent_t> sync_events;
+//    std::vector<ncclComm_t> comms;
+//    std::vector<int> comms_index;
 
     ExtTPContext
     (
@@ -104,4 +107,18 @@ void tp_cross_device_barrier
     int next_stage = -1
 );
 
+//void tp_all_reduce
+//(
+//    uintptr_t tp_context,
+//    const std::vector<torch::Tensor> &tensors
+//);
+
+void tp_all_reduce
+(
+    uintptr_t tp_context,
+    int buffer,
+    const std::vector<torch::Tensor> &tensors,
+    const std::vector<torch::Tensor> &residuals
+);
+
 #endif
@@ -21,6 +21,7 @@ class ExLlamaV2Linear(ExLlamaV2Module):
     name: str = "Linear"
 
     in_features: int
+    in_features_tp: list[int] | None
     out_features: int
     out_features_tp: list[int] | None
     has_bias: bool
@@ -46,6 +47,7 @@ class ExLlamaV2Linear(ExLlamaV2Module):
 
     is_tp: bool
     broadcast_type: int | None
+    broadcast_type_out: int | None
     is_sub_module: bool
 
     def __init__(self,
@@ -69,6 +71,7 @@ class ExLlamaV2Linear(ExLlamaV2Module):
 
         self.is_tp = False
         self.broadcast_type = None
+        self.broadcast_type_out = None
 
         if pad32:
             self.padding = -out_features % 32
@@ -260,9 +263,10 @@ class ExLlamaV2Linear(ExLlamaV2Module):
             (self.temp_fwd_size() if self.is_sub_module else 0)
 
 
-    def temp_dq_size(self, out_features = None) -> int:
+    def temp_dq_size(self, in_features = None, out_features = None) -> int:
 
-        dq = self.in_features * (self.out_features if out_features is None else out_features)
+        dq = (self.in_features if in_features is None else in_features) * \
+             (self.out_features if out_features is None else out_features)
         dq = 2 * min(dq, self.model.config.max_dq_size)
         return dq
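temp_dq_size gains an in_features override because a row-split shard only dequantizes its own slice of weight rows; the scratch requirement is 2 bytes (one fp16 element) per weight in the shard, capped at max_dq_size elements. For example, with assumed illustrative numbers:

# A 4096 -> 14336 projection row-split across 2 GPUs dequantizes a
# 2048 x 14336 shard per device (the max_dq_size value is assumed here):
in_features, out_features = 2048, 14336
max_dq_size = 512 * 1024 * 1024
dq_bytes = 2 * min(in_features * out_features, max_dq_size)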
@@ -301,17 +305,32 @@ class ExLlamaV2Linear(ExLlamaV2Module):
     ) -> torch.Tensor | dict[str: torch.Tensor]:
 
         if self.is_tp:
-            return self.forward_tp(
-                hidden_states,
-                cache,
-                attn_params,
-                past_len,
-                intermediates,
-                loras,
-                force_recons,
-                force_cuda,
-                **kwargs
-            )
+            if self.out_features_tp:
+                return self.forward_tp(
+                    hidden_states,
+                    cache,
+                    attn_params,
+                    past_len,
+                    intermediates,
+                    loras,
+                    force_recons,
+                    force_cuda,
+                    **kwargs
+                )
+            elif self.in_features_tp:
+                return self.forward_tp_row(
+                    hidden_states,
+                    cache,
+                    attn_params,
+                    past_len,
+                    intermediates,
+                    loras,
+                    force_recons,
+                    force_cuda,
+                    **kwargs
+                )
+            else:
+                assert False, "Uninitialized TP linear layer"
 
         # Linear forward
 
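The dispatch above separates the two sharding modes: a column-split layer (out_features_tp set) takes forward_tp, whose per-device outputs are slices of the result and can be gathered, while a row-split layer (in_features_tp set) takes forward_tp_row, whose per-device outputs are full-width partial sums that a later tp_all_reduce must combine (compare the sketch under the commit message).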
@@ -375,12 +394,13 @@ class ExLlamaV2Linear(ExLlamaV2Module):
         **kwargs
     ) -> torch.Tensor | dict[str: torch.Tensor]:
 
-        split = self.model.tp_context.get_split(self.broadcast_type)
+        ctx = self.model.tp_context
+        split = ctx.get_split(self.broadcast_type)
 
         if isinstance(hidden_states, torch.Tensor):
             output_shape = hidden_states.shape[:-1] + (self.out_features,)
             hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-            hidden_states = self.model.tp_context.broadcast(hidden_states, self.broadcast_type)
+            hidden_states = ctx.broadcast(hidden_states, self.broadcast_type)
         else:
             output_shape = hidden_states[0].shape[:-1] + (self.out_features,)
             hidden_states = [hs.view(-1, hs.shape[-1]) for hs in hidden_states]
@@ -398,18 +418,55 @@ class ExLlamaV2Linear(ExLlamaV2Module):
             self.q_handle,
             outputs,
             force_cuda,
-            self.model.tp_context.ext_tp_context,
+            ctx.ext_tp_context,
             -1
         )
 
         if output_split:
             return outputs
 
-        output = self.model.tp_context.gather(0, outputs, self.broadcast_type)
+        output = ctx.gather(0, outputs, self.broadcast_type)
         hidden_states_out = output.view(output_shape)
         return hidden_states_out
 
 
+    def forward_tp_row(
+        self,
+        hidden_states: list[torch.Tensor],
+        cache = None,
+        attn_params = None,
+        past_len = None,
+        intermediates: bool = False,
+        loras: list[ExLlamaV2Lora] | None = None,
+        force_recons: bool = False,
+        force_cuda: bool = False,
+        output_split: bool = False,
+        dim: int = 1,
+        **kwargs
+    ) -> torch.Tensor | dict[str: torch.Tensor]:
+
+        ctx = self.model.tp_context
+        split = ctx.get_split(self.broadcast_type)
+
+        assert isinstance(hidden_states, list)
+        assert output_split
+
+        rows = hidden_states[0].shape[0]
+        dtype = hidden_states[0].dtype
+        outputs = ctx.get_temp_tensors_bc(rows, dtype, self.broadcast_type_out)
+
+        ext_c.gemm_half_q_half_tp(
+            hidden_states,
+            self.q_handle,
+            outputs,
+            force_cuda,
+            ctx.ext_tp_context,
+            -1
+        )
+
+        return outputs
+
+
     def get_weight_tensor_dq(self) -> torch.Tensor:
 
         if self.linear is not None:
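forward_tp_row deliberately asserts output_split: its per-device outputs are not shards of the result but partial sums over each device's input slice, so they only become meaningful after reduction. A hypothetical caller pattern (the variable names buffer, residuals and hidden_states_split are assumed for illustration, not taken from this commit):

# partials: one full-width partial sum per device, from the row-split matmul
partials = down_proj.forward(hidden_states_split, output_split = True)

# combine the partials into every device's copy of the residual stream
ext_c.tp_all_reduce(ctx.ext_tp_context, buffer, partials, residuals)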
@@ -470,8 +527,6 @@ class ExLlamaV2Linear(ExLlamaV2Module):
         assert all(x in self.q_tensors for x in [
             "q_scale",
             "q_scale_max",
-            "q_perm",
-            "q_invperm",
             "q_group_map",
             "q_groups",
             "q_weight"
@@ -481,7 +536,6 @@ class ExLlamaV2Linear(ExLlamaV2Module):
         ctx = self.model.tp_context
         self.broadcast_type = broadcast_type
         split = ctx.get_split(broadcast_type)
-        maxdev = max(dev for dev, _, _ in split)
 
         if dim:
             split = [(d, a * dim, b * dim) for (d, a, b) in split]
@@ -501,13 +555,17 @@ class ExLlamaV2Linear(ExLlamaV2Module):
             w = {
                 "q_scale": safe_move_tensor(self.q_tensors["q_scale"][:, a // 8:b // 8], idx).contiguous(),
                 "q_scale_max": safe_move_tensor(self.q_tensors["q_scale_max"], idx).contiguous(),
-                "q_perm": safe_move_tensor(self.q_tensors["q_perm"], idx).contiguous(),
-                "q_invperm": safe_move_tensor(self.q_tensors["q_invperm"], idx).contiguous(),
                 "q_group_map": safe_move_tensor(self.q_tensors["q_group_map"], idx).contiguous(),
                 "q_groups": safe_move_tensor(self.q_tensors["q_groups"], idx).contiguous(),
                 "q_weight": safe_move_tensor(self.q_tensors["q_weight"][:, a:b], idx).contiguous()
             }
 
+            if "q_perm" in self.q_tensors:
+                w.update({
+                    "q_perm": safe_move_tensor(self.q_tensors["q_perm"], idx).contiguous(),
+                    "q_invperm": safe_move_tensor(self.q_tensors["q_invperm"], idx).contiguous(),
+                })
+
             if "bias" in self.q_tensors:
                 w["bias"] = safe_move_tensor(self.q_tensors["bias"][a:b], idx).contiguous()
@@ -515,14 +573,14 @@ class ExLlamaV2Linear(ExLlamaV2Module):
 
             device_context = self.model.get_device_context(idx)
             device_context.begin_scratch_alloc()
-            new_temp_dq.append(device_context.get_scratch_slice(self.temp_dq_size(s)))
+            new_temp_dq.append(device_context.get_scratch_slice(self.temp_dq_size(out_features = s)))
             max_dq_rows = cfg.max_dq_size // s
 
             new_q_handle.append(
                 ext_c.make_q_matrix_split(
                     w["q_weight"],
-                    w["q_perm"],
-                    w["q_invperm"],
+                    w.get("q_perm", none_tensor),
+                    w.get("q_invperm", none_tensor),
                     w["q_scale"],
                     w["q_scale_max"],
                     w["q_groups"],
@@ -542,3 +600,106 @@ class ExLlamaV2Linear(ExLlamaV2Module):
         self.temp_dq = new_temp_dq
         self.is_tp = True
 
 
+    def tp_split_row(self, broadcast_type_in: int, broadcast_type_out: int, dim = None):
+
+        assert self.q_handle is not None, \
+            "Can only split quantized tensor."
+        assert all(x in self.q_tensors for x in [
+            "q_scale",
+            "q_scale_max",
+            "q_group_map",
+            "q_groups",
+            "q_weight"
+        ]), "Can only split fully loaded EXL2 tensor."
+        assert not any(x in self.q_tensors for x in [
+            "q_perm",
+            "q_invperm",
+        ]), "Tensor not prepared for row split"
+
+        cfg = self.model.config
+        ctx = self.model.tp_context
+        self.broadcast_type = broadcast_type_in
+        self.broadcast_type_out = broadcast_type_out
+        split = ctx.get_split(broadcast_type_in)
+
+        if dim:
+            split = [(d, a * dim, b * dim) for (d, a, b) in split]
+
+        self.in_features_tp = [0] * ctx.num_devices
+        for dev, a, b in split:
+            self.in_features_tp[dev] = b - a
+
+        new_q_handle = []
+        new_q_tensors = []
+        new_temp_dq = []
+
+        q_scale = self.q_tensors["q_scale"]
+        q_scale_max = self.q_tensors["q_scale_max"]
+        q_group_map = self.q_tensors["q_group_map"]
+        q_groups = self.q_tensors["q_groups"]
+        q_weight = self.q_tensors["q_weight"]
+
+        assert q_group_map[b * 2 - 1].item() == 1
+
+        for idx, a, b in split:
+            s = b - a
+            if s == 0: continue
+
+            groups = q_groups.shape[0] // 2
+            group_a = q_group_map[a * 2].item()
+            group_b = q_group_map[(b - 1) * 2].item() + 1
+            row_a = q_groups[group_a * 2 + 1].item()
+            row_b = q_weight.shape[0] if group_b == groups else q_groups[group_b * 2 + 1].item()
+
+            tq_scale = q_scale[group_a:group_b, :]
+            tq_scale_max = q_scale_max[group_a:group_b]
+            tq_group_map = q_group_map[a * 2:b * 2]
+            tq_groups = q_groups[group_a * 2:group_b * 2]
+            tq_weight = q_weight[row_a:row_b, :]
+
+            tq_group_map[::2] -= group_a
+            tq_groups[1::2] -= row_a
+
+            w = {
+                "q_scale": safe_move_tensor(tq_scale, idx).clone().contiguous(),
+                "q_scale_max": safe_move_tensor(tq_scale_max, idx).clone().contiguous(),
+                "q_group_map": safe_move_tensor(tq_group_map, idx).clone().contiguous(),
+                "q_groups": safe_move_tensor(tq_groups, idx).clone().contiguous(),
+                "q_weight": safe_move_tensor(tq_weight, idx).clone().contiguous()
+            }
+
+            if "bias" in self.q_tensors:
+                w["bias"] = safe_move_tensor(self.q_tensors["bias"], idx).contiguous()
+
+            new_q_tensors.append(w)
+
+            device_context = self.model.get_device_context(idx)
+            device_context.begin_scratch_alloc()
+            new_temp_dq.append(device_context.get_scratch_slice(self.temp_dq_size(in_features = s)))
+            max_dq_rows = cfg.max_dq_size
+
+            new_q_handle.append(
+                ext_c.make_q_matrix_split(
+                    w["q_weight"],
+                    none_tensor,
+                    none_tensor,
+                    w["q_scale"],
+                    w["q_scale_max"],
+                    w["q_groups"],
+                    w["q_group_map"],
+                    none_tensor,
+                    none_tensor,
+                    none_tensor,
+                    w.get("bias", none_tensor),
+                    new_temp_dq[-1],
+                    max_dq_rows
+                )
+            )
+
+        ext_c.free_q_matrix(self.q_handle)
+        self.q_handle = new_q_handle
+        self.q_tensors = new_q_tensors
+        self.temp_dq = new_temp_dq
+        self.is_tp = True
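Row splitting works on contiguous input ranges, which is why tp_split_row refuses tensors with q_perm/q_invperm: under act-order the packed rows of q_weight are permuted, and a contiguous slice would no longer map to a contiguous range of input features. It also implies each device's range must end on a quantization-group boundary, since a group's scale metadata cannot be divided between devices; that is what the q_group_map assert encodes. A minimal sketch of that alignment rule (the interpretation of the metadata layout is assumed from the slicing code above, not documented in this commit):

def split_is_group_aligned(split, q_group_map):
    # q_group_map stores two entries per input feature; by the indexing used
    # in tp_split_row, entry [2 * r + 1] reaches 1 on the last row of a group,
    # so every split endpoint b must land just past such a row.
    return all(q_group_map[b * 2 - 1].item() == 1 for _, _, b in split)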