Merge remote-tracking branch 'origin/main' into ik/nccl3_async

Author: Iwan Kawrakow
Date:   2025-12-25 07:57:23 +00:00
3 changed files with 106 additions and 54 deletions

File 1 of 3

@@ -1479,9 +1479,10 @@ static void ggml_cuda_op_mul_mat_cublas(
     GGML_UNUSED(src1_padded_row_size);
 }
 
-static void ggml_cuda_set_peer_access(int main_device) {
+static bool ggml_cuda_set_peer_access(int main_device) {
     ggml_cuda_set_device(main_device);
+    bool all_enabled = true;
     for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) {
         if (main_device == id_other) {
             continue;
@@ -1500,8 +1501,11 @@ static void ggml_cuda_set_peer_access(int main_device) {
                 // reset the error
                 (void)cudaGetLastError();
             }
-        }
+        } else {
+            all_enabled = false;
+        }
     }
+    return all_enabled;
 }
 
 static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
@@ -4453,7 +4457,7 @@ GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device, [[maybe_unused]] con
 #if !defined(GGML_CUDA_NO_PEER_COPY)
     if (enable_p2p) {
-        ggml_cuda_set_peer_access(device);
+        ctx->p2p_enabled = ggml_cuda_set_peer_access(device);
     }
 #endif
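
The two hunks above elide the middle of ggml_cuda_set_peer_access(). For context, its loop body follows the standard CUDA peer-access pattern; the sketch below is a reconstruction under that assumption (variable names are taken from the hunks, everything else is illustrative, not the verbatim source):

    // Assumed shape of the elided per-peer body of ggml_cuda_set_peer_access():
    int can_access_peer;
    CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, main_device, id_other));
    if (can_access_peer) {
        cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0);
        if (err != cudaErrorPeerAccessAlreadyEnabled) {
            CUDA_CHECK(err);            // any other error is a real failure
        } else {
            (void)cudaGetLastError();   // already enabled: clear the sticky error
        }
    } else {
        all_enabled = false;            // this device pair cannot use P2P
    }

With this change the function reports whether peer access ended up enabled for every device pair, and ggml_backend_cuda_init() records the result in ctx->p2p_enabled so the reduce op can decide at run time whether the direct peer-memory paths are safe to take.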

File 2 of 3

@@ -843,6 +843,7 @@ struct ggml_backend_cuda_context {
     int device;
     std::string name;
     cudaEvent_t copy_event = nullptr;
+    bool p2p_enabled = false;
 
     cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
     cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

File 3 of 3

@@ -45,6 +45,23 @@ static __global__ void k_reduce_add(copy_task task) {
     }
 }
 
+template <typename T, int block_size, int nptr>
+static __global__ void k_reduce_add_T(copy_task task) {
+    int i = blockIdx.x*block_size + threadIdx.x;
+    if (i >= task.nelem) return;
+    auto dst = (T *)task.ptrs[0];
+#pragma unroll
+    for (int j = 1; j < nptr; ++j) {
+        auto src = (T *)task.ptrs[j];
+        dst[i] += src[i];
+    }
+#pragma unroll
+    for (int j = 1; j < nptr; ++j) {
+        auto src = (T *)task.ptrs[j];
+        src[i] = dst[i];
+    }
+}
+
 void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     auto op = (ggml_op)dst->op_params[0];
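
k_reduce_add_T is an in-place all-reduce kernel: thread i first accumulates element i of every peer buffer into ptrs[0], then writes the sum back to every peer, so on return all nptr buffers hold identical results. Making the pointer count a compile-time template parameter lets #pragma unroll fully unroll both loops, which the runtime-count k_reduce_add (kept as the fallback in the hunks below) cannot do. A hedged host-side sketch; the copy_task field names match their usage in the hunks below, while the buffer names and the stream are illustrative assumptions:

    // Sketch: all-reduce n floats across 3 GPUs whose buffers are P2P-mapped.
    copy_task task;
    task.nptr  = 3;                 // runtime count, must match the template argument
    task.nelem = n;
    task.ptrs[0] = (char *)buf0;    // accumulates the sums first
    task.ptrs[1] = (char *)buf1;    // peer buffers, directly addressable via P2P
    task.ptrs[2] = (char *)buf2;
    int nblock = (n + CUDA_REDUCE_BLOCK_SIZE - 1)/CUDA_REDUCE_BLOCK_SIZE;
    k_reduce_add_T<float, CUDA_REDUCE_BLOCK_SIZE, 3><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, stream>>>(task);
    // Afterwards buf0, buf1 and buf2 all contain the elementwise sum.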
@@ -146,53 +163,6 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
 #endif
     GGML_ASSERT(dst->data == dst->src[ctx.device]->data);
     auto nbytes = ggml_nbytes(dst);
-    if (nhave == 2 && (nhave == nreduce || dst->ne[1] <= 8)) {
-        int idx[2];
-        int ii = 0;
-        for (int i = 0; i < nreduce; ++i) {
-            if (dst->src[i]) {
-                idx[ii++] = i;
-            }
-        }
-        // With P2P access enabled, we can access peer memory as if it were local.
-        // Hence, we can launch two reduce kernels, one on each device, each kernel
-        // processing half of the data. This very simple approach almost matches NCCL
-        // performance (I see ~1% lower PP and TG performance on my 2x3090 system).
-        for (int i = 0; i < nhave; ++i) {
-            GGML_ASSERT(dst->src[idx[i]]->type == dst->type);
-            GGML_ASSERT(ggml_are_same_shape(dst, dst->src[idx[i]]));
-            ggml_cuda_set_device(idx[i]);
-            if (!info.all_ctx[idx[i]]->copy_event) {
-                CUDA_CHECK(cudaEventCreateWithFlags(&info.all_ctx[idx[i]]->copy_event, cudaEventDisableTiming));
-            }
-            CUDA_CHECK(cudaEventRecord(info.all_ctx[idx[i]]->copy_event, info.all_ctx[idx[i]]->stream()));
-        }
-        auto nelem = ggml_nelements(dst);
-        auto nelem_half = (nelem + 1)/2;
-        for (int i = 0; i < nhave; ++i) {
-            ggml_cuda_set_device(idx[i]);
-            CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[idx[i]]->stream(), info.all_ctx[idx[(i+1)%2]]->copy_event, 0));
-            auto this_nelem = std::min(nelem_half, nelem - nelem_half);
-            int nblock = (this_nelem + CUDA_REDUCE_BLOCK_SIZE - 1)/CUDA_REDUCE_BLOCK_SIZE;
-            if (dst->type == GGML_TYPE_F16) {
-                auto src_ptr = (half *)dst->src[idx[i]]->data + i*nelem_half;
-                auto dst_ptr = (half *)dst->src[idx[(i+1)%2]]->data + i*nelem_half;
-                k_add_sym<half, CUDA_REDUCE_BLOCK_SIZE><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[idx[i]]->stream()>>>(this_nelem, src_ptr, dst_ptr);
-            } else {
-                auto src_ptr = (float *)dst->src[idx[i]]->data + i*nelem_half;
-                auto dst_ptr = (float *)dst->src[idx[(i+1)%2]]->data + i*nelem_half;
-                k_add_sym<float, CUDA_REDUCE_BLOCK_SIZE><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[idx[i]]->stream()>>>(this_nelem, src_ptr, dst_ptr);
-            }
-        }
-        for (int i = 0; i < nhave; ++i) {
-            ggml_cuda_set_device(idx[i]);
-            CUDA_CHECK(cudaEventRecord(info.all_ctx[idx[i]]->copy_event, info.all_ctx[idx[i]]->stream()));
-            ggml_cuda_set_device(idx[(i+1)%2]);
-            CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[idx[(i+1)%2]]->stream(), info.all_ctx[idx[i]]->copy_event));
-        }
-        ggml_cuda_set_device(ctx.device);
-        return;
-    }
     int idx[GGML_CUDA_MAX_DEVICES];
     {
         int ii = 0;
@@ -206,7 +176,7 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
         GGML_ASSERT(ii == nhave);
         GGML_ASSERT(have_this_device);
     }
-    if (nhave == 4 && dst->ne[1] <= 8) {
+    if (nhave == 4 && dst->ne[1] <= 8 && ctx.p2p_enabled) {
         for (int ii = 0; ii < nhave; ++ii) {
             int i = idx[ii];
             GGML_ASSERT(dst->src[i]->type == dst->type);
@@ -230,9 +200,9 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
             task.ptrs[1] = (char *)dst->src[j]->data;
             CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[j]->copy_event));
             if (dst->type == GGML_TYPE_F16) {
-                k_reduce_add<half, CUDA_REDUCE_BLOCK_SIZE><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
+                k_reduce_add_T<half, CUDA_REDUCE_BLOCK_SIZE, 2><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
             } else {
-                k_reduce_add<float, CUDA_REDUCE_BLOCK_SIZE><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
+                k_reduce_add_T<float, CUDA_REDUCE_BLOCK_SIZE, 2><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
             }
         }
         for (int ii = 0; ii < nhave/2; ++ii) {
@@ -252,9 +222,9 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
             task.ptrs[1] = (char *)dst->src[j]->data;
             CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[j]->copy_event));
             if (dst->type == GGML_TYPE_F16) {
-                k_reduce_add<half, CUDA_REDUCE_BLOCK_SIZE><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
+                k_reduce_add_T<half, CUDA_REDUCE_BLOCK_SIZE, 2><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
             } else {
-                k_reduce_add<float, CUDA_REDUCE_BLOCK_SIZE><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
+                k_reduce_add_T<float, CUDA_REDUCE_BLOCK_SIZE, 2><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
             }
         }
         for (int ii = 0; ii < nhave/2; ++ii) {
@@ -271,6 +241,83 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
         ggml_cuda_set_device(ctx.device);
         return;
     }
+    if (dst->ne[1] <= 8 && ctx.p2p_enabled) {
+        for (int ii = 0; ii < nhave; ++ii) {
+            int i = idx[ii];
+            GGML_ASSERT(dst->src[i]->type == dst->type);
+            GGML_ASSERT(ggml_are_same_shape(dst, dst->src[i]));
+            ggml_cuda_set_device(i);
+            if (!info.all_ctx[i]->copy_event) {
+                CUDA_CHECK(cudaEventCreateWithFlags(&info.all_ctx[i]->copy_event, cudaEventDisableTiming));
+            }
+            CUDA_CHECK(cudaEventRecord(info.all_ctx[i]->copy_event, info.all_ctx[i]->stream()));
+        }
+        //printf("Recorded events\n");
+        auto nelem = ggml_nelements(dst);
+        auto nelem_per_device = (nelem + nhave - 1)/nhave;
+        auto elem_size = ggml_element_size(dst);
+        for (int ii = 0; ii < nhave; ++ii) {
+            int i = idx[ii];
+            int this_nelem = std::min(nelem_per_device, nelem - ii*nelem_per_device);
+            copy_task task;
+            task.nptr = nhave;
+            task.nelem = this_nelem;
+            task.ptrs[0] = (char *)dst->src[i]->data + ii*nelem_per_device*elem_size;
+            int k = 1;
+            for (int jj = 0; jj < nhave; ++jj) {
+                if (jj == ii) continue;
+                int j = idx[jj];
+                CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[j]->copy_event));
+                task.ptrs[k++] = (char *)dst->src[j]->data + ii*nelem_per_device*elem_size;
+            }
+            int nblock = (this_nelem + CUDA_REDUCE_BLOCK_SIZE - 1)/CUDA_REDUCE_BLOCK_SIZE;
+            if (dst->type == GGML_TYPE_F16) {
+                switch (nhave) {
+                    case 2:
+                        k_reduce_add_T<half, CUDA_REDUCE_BLOCK_SIZE, 2><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
+                        break;
+                    case 3:
+                        k_reduce_add_T<half, CUDA_REDUCE_BLOCK_SIZE, 3><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
+                        break;
+                    case 4:
+                        k_reduce_add_T<half, CUDA_REDUCE_BLOCK_SIZE, 4><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
+                        break;
+                    default:
+                        k_reduce_add<half, CUDA_REDUCE_BLOCK_SIZE><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
+                }
+            } else {
+                switch (nhave) {
+                    case 2:
+                        k_reduce_add_T<float, CUDA_REDUCE_BLOCK_SIZE, 2><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
+                        break;
+                    case 3:
+                        k_reduce_add_T<float, CUDA_REDUCE_BLOCK_SIZE, 3><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
+                        break;
+                    case 4:
+                        k_reduce_add_T<float, CUDA_REDUCE_BLOCK_SIZE, 4><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
+                        break;
+                    default:
+                        k_reduce_add<float, CUDA_REDUCE_BLOCK_SIZE><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
+                }
+            }
+        }
+        //printf("Submitted kernels\n");
+        for (int ii = 0; ii < nhave; ++ii) {
+            int i = idx[ii];
+            CUDA_CHECK(cudaEventRecord(info.all_ctx[i]->copy_event, info.all_ctx[i]->stream()));
+        }
+        //printf("Recorded events again\n");
+        for (int ii = 0; ii < nhave; ++ii) {
+            int i = idx[ii];
+            for (int jj = 0; jj < nhave; ++jj) {
+                if (jj == ii) continue;
+                int j = idx[jj];
+                CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[j]->copy_event));
+            }
+        }
+        //printf("All good so far\n");
+        return;
+    }
     auto required_size = nbytes*(nhave-1);
     if (required_size > ctx.copy_size) {
         if (ctx.copy_buffer) {
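
The block added in the last hunk generalizes the removed two-GPU special case into a one-shot all-reduce over any 2 to 4 devices: the tensor is split into nhave contiguous slices, device ii all-reduces slice ii in place across all peers with k_reduce_add_T, and the two rounds of copy_event record/wait act as barriers so no device reads a peer slice before it is produced or reuses its own buffer before every peer is done with it. The trailing context appears to begin the non-P2P fallback, which stages peer data through a local copy buffer. A worked example of the slice arithmetic, with values chosen purely for illustration:

    // nelem = 1000 elements, nhave = 3 devices:
    // nelem_per_device = (1000 + 3 - 1)/3 = 334            (ceiling division)
    // slice 0: elements [  0,  334) -> this_nelem = min(334, 1000 -   0) = 334
    // slice 1: elements [334,  668) -> this_nelem = min(334, 1000 - 334) = 334
    // slice 2: elements [668, 1000) -> this_nelem = min(334, 1000 - 668) = 332
    // Each device launches ceil(this_nelem/CUDA_REDUCE_BLOCK_SIZE) blocks on its own stream.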