diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 269ed37d..ed0b8b56 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2321,54 +2321,127 @@ static __global__ void k_quick_add(uint32_t n_per_row, const float * src1, const
     }
 }
 
+static __global__ void k_moe_row_count(
+        const char * __restrict__ ids,
+        size_t nb0, size_t nb1,
+        int64_t n_ids, int64_t n_rows,
+        int32_t n_as,
+        int32_t * __restrict__ moe_counts,
+        int32_t * __restrict__ has_invalid_ids) {
+    const int64_t idx = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
+    const int64_t ne = n_ids * n_rows;
+
+    if (idx >= ne) {
+        return;
+    }
+
+    const int64_t iid1 = idx / n_ids;
+    const int64_t id = idx - iid1*n_ids;
+    const int32_t row_id_i = *(const int32_t *) (ids + iid1*nb1 + id*nb0);
+
+    if ((uint32_t) row_id_i < (uint32_t) n_as) {
+        atomicAdd(moe_counts + row_id_i, 1);
+    } else {
+        atomicExch(has_invalid_ids, 1);
+    }
+}
+
+static __global__ void k_moe_row_exclusive_scan(
+        const int32_t * __restrict__ moe_counts,
+        int32_t * __restrict__ cum_moe_counts,
+        int32_t n_as) {
+    if (blockIdx.x != 0 || threadIdx.x != 0) {
+        return;
+    }
+
+    int32_t sum = 0;
+    cum_moe_counts[0] = 0;
+    for (int i = 0; i < n_as; ++i) {
+        sum += moe_counts[i];
+        cum_moe_counts[i + 1] = sum;
+    }
+}
+
+static __global__ void k_moe_row_scatter(
+        const char * __restrict__ ids,
+        size_t nb0, size_t nb1,
+        int64_t n_ids, int64_t n_rows,
+        int32_t n_as,
+        int32_t * __restrict__ row_offsets,
+        mmid_row_mapping * __restrict__ row_mapping) {
+    const int64_t idx = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
+    const int64_t ne = n_ids * n_rows;
+
+    if (idx >= ne) {
+        return;
+    }
+
+    const int64_t iid1 = idx / n_ids;
+    const int64_t id = idx - iid1*n_ids;
+    const int32_t row_id_i = *(const int32_t *) (ids + iid1*nb1 + id*nb0);
+
+    if ((uint32_t) row_id_i < (uint32_t) n_as) {
+        const int32_t dst_idx = atomicAdd(row_offsets + row_id_i, 1);
+        row_mapping[dst_idx] = { (int32_t) id, (int32_t) iid1 };
+    }
+}
+
 static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids, const ggml_tensor * ids,
         std::vector<int>& moe_counts, std::vector<int>& cum_moe_counts, ggml_cuda_pool_alloc<mmid_row_mapping>& dev_row_mapping) {
 
     GGML_ASSERT(moe_counts.empty() && cum_moe_counts.empty());
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+    GGML_ASSERT(n_as <= (int64_t) std::numeric_limits<int32_t>::max());
+    GGML_ASSERT(n_ids <= (int64_t) std::numeric_limits<int32_t>::max());
+    GGML_ASSERT(ids->ne[1] <= (int64_t) std::numeric_limits<int32_t>::max());
 
     auto stream = ctx.stream();
 
-    std::vector<char> ids_host(ggml_nbytes(ids));
     const char * ids_dev = (const char *) ids->data;
-    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    const int64_t n_rows = ids->ne[1];
+    const int64_t n_entries = n_rows*n_ids;
+
+    moe_counts.resize(n_as, 0);
+    cum_moe_counts.resize(n_as + 1, 0);
+
+    if (n_entries == 0 || n_as == 0) {
+        return false;
+    }
+
+    // Build row mapping fully on-device to avoid per-call ids D2H/H2D round-trips.
+    dev_row_mapping.alloc(n_entries);
+
+    ggml_cuda_pool_alloc<int32_t> dev_moe_counts(ctx.pool(), n_as);
+    ggml_cuda_pool_alloc<int32_t> dev_cum_moe_counts(ctx.pool(), n_as + 1);
+    ggml_cuda_pool_alloc<int32_t> dev_row_offsets(ctx.pool(), n_as);
+    ggml_cuda_pool_alloc<int32_t> dev_has_invalid_ids(ctx.pool(), 1);
+
+    CUDA_CHECK(cudaMemsetAsync(dev_moe_counts.get(), 0, n_as*sizeof(int32_t), stream));
+    CUDA_CHECK(cudaMemsetAsync(dev_has_invalid_ids.get(), 0, sizeof(int32_t), stream));
+
+    constexpr int block_size = 256;
+    const dim3 grid_dims((n_entries + block_size - 1) / block_size, 1, 1);
+    k_moe_row_count<<<grid_dims, block_size, 0, stream>>>(
+        ids_dev, ids->nb[0], ids->nb[1], n_ids, n_rows, n_as, dev_moe_counts.get(), dev_has_invalid_ids.get());
+    CUDA_CHECK(cudaGetLastError());
+
+    k_moe_row_exclusive_scan<<<1, 1, 0, stream>>>(dev_moe_counts.get(), dev_cum_moe_counts.get(), n_as);
+    CUDA_CHECK(cudaGetLastError());
+
+    CUDA_CHECK(cudaMemcpyAsync(dev_row_offsets.get(), dev_cum_moe_counts.get(), n_as*sizeof(int32_t), cudaMemcpyDeviceToDevice, stream));
+
+    k_moe_row_scatter<<<grid_dims, block_size, 0, stream>>>(
+        ids_dev, ids->nb[0], ids->nb[1], n_ids, n_rows, n_as, dev_row_offsets.get(), dev_row_mapping.get());
+    CUDA_CHECK(cudaGetLastError());
+
+    int32_t has_invalid_ids = 0;
+    CUDA_CHECK(cudaMemcpyAsync(moe_counts.data(), dev_moe_counts.get(), n_as*sizeof(int), cudaMemcpyDeviceToHost, stream));
+    CUDA_CHECK(cudaMemcpyAsync(cum_moe_counts.data(), dev_cum_moe_counts.get(), (n_as + 1)*sizeof(int), cudaMemcpyDeviceToHost, stream));
+    CUDA_CHECK(cudaMemcpyAsync(&has_invalid_ids, dev_has_invalid_ids.get(), sizeof(int32_t), cudaMemcpyDeviceToHost, stream));
     CUDA_CHECK(cudaStreamSynchronize(stream));
 
-    std::vector<mmid_row_mapping> rmapping(ids->ne[1]*n_ids);
-    moe_counts.resize(n_as, 0);
-    cum_moe_counts.resize(n_as + 1);
-
-    bool is_ser = false;
-    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-        for (int64_t id = 0; id < n_ids; id++) {
-            const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-            if (row_id_i >= 0 && row_id_i < n_as) ++moe_counts[row_id_i];
-            else is_ser = true;
-        }
-    }
-    cum_moe_counts[0] = 0;
-    for (int i = 0; i < (int)n_as; ++i) {
-        cum_moe_counts[i+1] = cum_moe_counts[i] + moe_counts[i];
-    }
-
-    dev_row_mapping.alloc(cum_moe_counts[n_as]);
-
-    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-        for (int64_t id = 0; id < n_ids; id++) {
-            const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-            if (row_id_i >= 0 && row_id_i < n_as) {
-                rmapping[cum_moe_counts[row_id_i]++] = {(int)id, (int)iid1};
-            }
-        }
-    }
-
-    for (int i = 0; i < (int)n_as; ++i) cum_moe_counts[i] -= moe_counts[i];
-
-    CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(),
-                               cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream));
-    //CUDA_CHECK(cudaStreamSynchronize(stream));
-
-    return is_ser;
+    return has_invalid_ids != 0;
 }
 
 static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
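
For reference, here is a minimal, self-contained sketch of the count -> exclusive-scan -> scatter pattern the new kernels use to group (row, expert-slot) entries into contiguous per-expert buckets entirely on the GPU. All names and sizes below (row_mapping, count_ids, exclusive_scan, scatter_ids, the toy ids array) are illustrative stand-ins rather than ggml API; the invalid-id flag and CUDA error checking are omitted for brevity, and ordering inside each bucket is arbitrary because slots are claimed with atomicAdd.

// sketch.cu -- illustrative only, not part of the patch
#include <cstdio>
#include <cstdint>
#include <cuda_runtime.h>

struct row_mapping { int32_t slot; int32_t row; };   // stand-in for mmid_row_mapping

// Pass 1: histogram of how many entries pick each expert.
static __global__ void count_ids(const int32_t * ids, int64_t n, int32_t n_experts, int32_t * counts) {
    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) return;
    const int32_t e = ids[i];
    if ((uint32_t) e < (uint32_t) n_experts) atomicAdd(counts + e, 1);
}

// Pass 2: exclusive prefix sum -> start offset of each expert's bucket (serial; n_experts is tiny).
static __global__ void exclusive_scan(const int32_t * counts, int32_t * offsets, int32_t n_experts) {
    if (blockIdx.x != 0 || threadIdx.x != 0) return;
    int32_t sum = 0;
    offsets[0] = 0;
    for (int i = 0; i < n_experts; ++i) { sum += counts[i]; offsets[i + 1] = sum; }
}

// Pass 3: scatter each entry into its expert's bucket, claiming a slot atomically.
static __global__ void scatter_ids(const int32_t * ids, int64_t n, int64_t n_ids_per_row,
                                   int32_t n_experts, int32_t * cursors, row_mapping * out) {
    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) return;
    const int32_t e = ids[i];
    if ((uint32_t) e < (uint32_t) n_experts) {
        const int32_t dst = atomicAdd(cursors + e, 1);
        out[dst] = { (int32_t)(i % n_ids_per_row), (int32_t)(i / n_ids_per_row) };
    }
}

int main() {
    const int32_t n_experts = 4;
    const int64_t n_rows = 3, n_ids = 2, n = n_rows*n_ids;
    const int32_t h_ids[n] = { 0, 2, 2, 1, 3, 2 };   // expert chosen per (row, slot), row-major

    int32_t *d_ids, *d_counts, *d_offsets, *d_cursors;
    row_mapping *d_map;
    cudaMalloc(&d_ids,     n*sizeof(int32_t));
    cudaMalloc(&d_counts,  n_experts*sizeof(int32_t));
    cudaMalloc(&d_offsets, (n_experts + 1)*sizeof(int32_t));
    cudaMalloc(&d_cursors, n_experts*sizeof(int32_t));
    cudaMalloc(&d_map,     n*sizeof(row_mapping));
    cudaMemcpy(d_ids, h_ids, n*sizeof(int32_t), cudaMemcpyHostToDevice);
    cudaMemset(d_counts, 0, n_experts*sizeof(int32_t));

    const int block = 256, grid = (int)((n + block - 1)/block);
    count_ids<<<grid, block>>>(d_ids, n, n_experts, d_counts);
    exclusive_scan<<<1, 1>>>(d_counts, d_offsets, n_experts);
    // Cursors start at each expert's bucket offset and advance atomically during the scatter.
    cudaMemcpy(d_cursors, d_offsets, n_experts*sizeof(int32_t), cudaMemcpyDeviceToDevice);
    scatter_ids<<<grid, block>>>(d_ids, n, n_ids, n_experts, d_cursors, d_map);

    int32_t h_off[n_experts + 1];
    row_mapping h_map[n];
    cudaMemcpy(h_off, d_offsets, sizeof(h_off), cudaMemcpyDeviceToHost);
    cudaMemcpy(h_map, d_map, sizeof(h_map), cudaMemcpyDeviceToHost);
    for (int e = 0; e < n_experts; ++e) {
        printf("expert %d:", e);
        for (int k = h_off[e]; k < h_off[e + 1]; ++k) printf(" (row %d, slot %d)", h_map[k].row, h_map[k].slot);
        printf("\n");
    }
    return 0;
}

The single-thread scan is deliberate in this sketch, mirroring k_moe_row_exclusive_scan in the patch: the number of experts is small, so a serial prefix sum on the device is cheaper than synchronizing back to the host.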