mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
Reduce add improvemens without NCCL (#1088)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -45,6 +45,23 @@ static __global__ void k_reduce_add(copy_task task) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int nptr>
|
||||
static __global__ void k_reduce_add_T(copy_task task) {
|
||||
int i = blockIdx.x*block_size + threadIdx.x;
|
||||
if (i >= task.nelem) return;
|
||||
auto dst = (T *)task.ptrs[0];
|
||||
#pragma unroll
|
||||
for (int j = 1; j < nptr; ++j) {
|
||||
auto src = (T *)task.ptrs[j];
|
||||
dst[i] += src[i];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j = 1; j < nptr; ++j) {
|
||||
auto src = (T *)task.ptrs[j];
|
||||
src[i] = dst[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
auto op = (ggml_op)dst->op_params[0];
|
||||
@@ -146,53 +163,6 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
|
||||
#endif
|
||||
GGML_ASSERT(dst->data == dst->src[ctx.device]->data);
|
||||
auto nbytes = ggml_nbytes(dst);
|
||||
if (nhave == 2 && (nhave == nreduce || dst->ne[1] <= 8)) {
|
||||
int idx[2];
|
||||
int ii = 0;
|
||||
for (int i = 0; i < nreduce; ++i) {
|
||||
if (dst->src[i]) {
|
||||
idx[ii++] = i;
|
||||
}
|
||||
}
|
||||
// With P2P access enabled, we can access peer memory so as if it was local.
|
||||
// Hence, we can launch two reduce kernels, one on each device, each kernel
|
||||
// processing half of the data. This very simply approach almost matches NCCL
|
||||
// performance (I see ~1% lower PP and TG performance on my 2x3090 system).
|
||||
for (int i = 0; i < nhave; ++i) {
|
||||
GGML_ASSERT(dst->src[idx[i]]->type == dst->type);
|
||||
GGML_ASSERT(ggml_are_same_shape(dst, dst->src[idx[i]]));
|
||||
ggml_cuda_set_device(idx[i]);
|
||||
if (!info.all_ctx[idx[i]]->copy_event) {
|
||||
CUDA_CHECK(cudaEventCreateWithFlags(&info.all_ctx[idx[i]]->copy_event, cudaEventDisableTiming));
|
||||
}
|
||||
CUDA_CHECK(cudaEventRecord(info.all_ctx[idx[i]]->copy_event, info.all_ctx[idx[i]]->stream()));
|
||||
}
|
||||
auto nelem = ggml_nelements(dst);
|
||||
auto nelem_half = (nelem + 1)/2;
|
||||
for (int i = 0; i < nhave; ++i) {
|
||||
ggml_cuda_set_device(idx[i]);
|
||||
CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[idx[i]]->stream(), info.all_ctx[idx[(i+1)%2]]->copy_event, 0));
|
||||
auto this_nelem = std::min(nelem_half, nelem - nelem_half);
|
||||
int nblock = (this_nelem + CUDA_REDUCE_BLOCK_SIZE - 1)/CUDA_REDUCE_BLOCK_SIZE;
|
||||
if (dst->type == GGML_TYPE_F16) {
|
||||
auto src_ptr = (half *)dst->src[idx[i]]->data + i*nelem_half;
|
||||
auto dst_ptr = (half *)dst->src[idx[(i+1)%2]]->data + i*nelem_half;
|
||||
k_add_sym<half, CUDA_REDUCE_BLOCK_SIZE><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[idx[i]]->stream()>>>(this_nelem, src_ptr, dst_ptr);
|
||||
} else {
|
||||
auto src_ptr = (float *)dst->src[idx[i]]->data + i*nelem_half;
|
||||
auto dst_ptr = (float *)dst->src[idx[(i+1)%2]]->data + i*nelem_half;
|
||||
k_add_sym<float, CUDA_REDUCE_BLOCK_SIZE><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[idx[i]]->stream()>>>(this_nelem, src_ptr, dst_ptr);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < nhave; ++i) {
|
||||
ggml_cuda_set_device(idx[i]);
|
||||
CUDA_CHECK(cudaEventRecord(info.all_ctx[idx[i]]->copy_event, info.all_ctx[idx[i]]->stream()));
|
||||
ggml_cuda_set_device(idx[(i+1)%2]);
|
||||
CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[idx[(i+1)%2]]->stream(), info.all_ctx[idx[i]]->copy_event));
|
||||
}
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
return;
|
||||
}
|
||||
int idx[GGML_CUDA_MAX_DEVICES];
|
||||
{
|
||||
int ii = 0;
|
||||
@@ -230,9 +200,9 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
|
||||
task.ptrs[1] = (char *)dst->src[j]->data;
|
||||
CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[j]->copy_event));
|
||||
if (dst->type == GGML_TYPE_F16) {
|
||||
k_reduce_add<half, CUDA_REDUCE_BLOCK_SIZE><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
|
||||
k_reduce_add_T<half, CUDA_REDUCE_BLOCK_SIZE, 2><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
|
||||
} else {
|
||||
k_reduce_add<float, CUDA_REDUCE_BLOCK_SIZE><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
|
||||
k_reduce_add_T<float, CUDA_REDUCE_BLOCK_SIZE, 2><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
|
||||
}
|
||||
}
|
||||
for (int ii = 0; ii < nhave/2; ++ii) {
|
||||
@@ -252,9 +222,9 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
|
||||
task.ptrs[1] = (char *)dst->src[j]->data;
|
||||
CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[j]->copy_event));
|
||||
if (dst->type == GGML_TYPE_F16) {
|
||||
k_reduce_add<half, CUDA_REDUCE_BLOCK_SIZE><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
|
||||
k_reduce_add_T<half, CUDA_REDUCE_BLOCK_SIZE, 2><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
|
||||
} else {
|
||||
k_reduce_add<float, CUDA_REDUCE_BLOCK_SIZE><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
|
||||
k_reduce_add_T<float, CUDA_REDUCE_BLOCK_SIZE, 2><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
|
||||
}
|
||||
}
|
||||
for (int ii = 0; ii < nhave/2; ++ii) {
|
||||
@@ -271,6 +241,83 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
return;
|
||||
}
|
||||
if (dst->ne[1] <= 8) {
|
||||
for (int ii = 0; ii < nhave; ++ii) {
|
||||
int i = idx[ii];
|
||||
GGML_ASSERT(dst->src[i]->type == dst->type);
|
||||
GGML_ASSERT(ggml_are_same_shape(dst, dst->src[i]));
|
||||
ggml_cuda_set_device(i);
|
||||
if (!info.all_ctx[i]->copy_event) {
|
||||
CUDA_CHECK(cudaEventCreateWithFlags(&info.all_ctx[i]->copy_event, cudaEventDisableTiming));
|
||||
}
|
||||
CUDA_CHECK(cudaEventRecord(info.all_ctx[i]->copy_event, info.all_ctx[i]->stream()));
|
||||
}
|
||||
//printf("Recorded events\n");
|
||||
auto nelem = ggml_nelements(dst);
|
||||
auto nelem_per_device = (nelem + nhave - 1)/nhave;
|
||||
auto elem_size = ggml_element_size(dst);
|
||||
for (int ii = 0; ii < nhave; ++ii) {
|
||||
int i = idx[ii];
|
||||
int this_nelem = std::min(nelem_per_device, nelem - ii*nelem_per_device);
|
||||
copy_task task;
|
||||
task.nptr = nhave;
|
||||
task.nelem = this_nelem;
|
||||
task.ptrs[0] = (char *)dst->src[i]->data + ii*nelem_per_device*elem_size;
|
||||
int k = 1;
|
||||
for (int jj = 0; jj < nhave; ++jj) {
|
||||
if (jj == ii) continue;
|
||||
int j = idx[jj];
|
||||
CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[j]->copy_event));
|
||||
task.ptrs[k++] = (char *)dst->src[j]->data + ii*nelem_per_device*elem_size;
|
||||
}
|
||||
int nblock = (this_nelem + CUDA_REDUCE_BLOCK_SIZE - 1)/CUDA_REDUCE_BLOCK_SIZE;
|
||||
if (dst->type == GGML_TYPE_F16) {
|
||||
switch (nhave) {
|
||||
case 2:
|
||||
k_reduce_add_T<half, CUDA_REDUCE_BLOCK_SIZE, 2><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
|
||||
break;
|
||||
case 3:
|
||||
k_reduce_add_T<half, CUDA_REDUCE_BLOCK_SIZE, 3><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
|
||||
break;
|
||||
case 4:
|
||||
k_reduce_add_T<half, CUDA_REDUCE_BLOCK_SIZE, 4><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
|
||||
break;
|
||||
default:
|
||||
k_reduce_add<half, CUDA_REDUCE_BLOCK_SIZE><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
|
||||
}
|
||||
} else {
|
||||
switch (nhave) {
|
||||
case 2:
|
||||
k_reduce_add_T<float, CUDA_REDUCE_BLOCK_SIZE, 2><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
|
||||
break;
|
||||
case 3:
|
||||
k_reduce_add_T<float, CUDA_REDUCE_BLOCK_SIZE, 3><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
|
||||
break;
|
||||
case 4:
|
||||
k_reduce_add_T<float, CUDA_REDUCE_BLOCK_SIZE, 4><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
|
||||
break;
|
||||
default:
|
||||
k_reduce_add<float, CUDA_REDUCE_BLOCK_SIZE><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
|
||||
}
|
||||
}
|
||||
}
|
||||
//printf("Submitted kernels\n");
|
||||
for (int ii = 0; ii < nhave; ++ii) {
|
||||
int i = idx[ii];
|
||||
CUDA_CHECK(cudaEventRecord(info.all_ctx[i]->copy_event, info.all_ctx[i]->stream()));
|
||||
}
|
||||
//printf("Recorded events again\n");
|
||||
for (int ii = 0; ii < nhave; ++ii) {
|
||||
int i = idx[ii];
|
||||
for (int jj = 0; jj < nhave; ++jj) {
|
||||
if (jj == ii) continue;
|
||||
int j = idx[jj];
|
||||
CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[j]->copy_event));
|
||||
}
|
||||
}
|
||||
//printf("All good so far\n");
|
||||
return;
|
||||
}
|
||||
auto required_size = nbytes*(nhave-1);
|
||||
if (required_size > ctx.copy_size) {
|
||||
if (ctx.copy_buffer) {
|
||||
|
||||
Reference in New Issue
Block a user