cuda: build MoE row mapping on device in mul_mat_id

yurko
2026-02-06 13:52:33 +00:00
parent 9fbb50481e
commit 89e9ecfa84

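Context for the change below: the row mapping groups every (token row, expert slot) entry of the ids tensor contiguously by expert, so each expert's matrix multiplication can gather exactly the rows routed to it. The kernels added here build that mapping on the device with a count / exclusive-scan / scatter pass instead of copying ids to the host. For illustration only, a minimal host-side sketch of the same construction follows; it assumes a densely packed int32 ids matrix, and the names build_row_mapping_host and row_mapping are made up for the sketch, whereas the real code walks ids through its byte strides nb0/nb1 and stores mmid_row_mapping entries.

#include <cstdint>
#include <vector>

struct row_mapping { int32_t id; int32_t iid1; };   // expert slot within the token, token row

static bool build_row_mapping_host(const int32_t * ids, int64_t n_ids, int64_t n_rows, int32_t n_as,
                                   std::vector<int32_t> & counts, std::vector<int32_t> & cum,
                                   std::vector<row_mapping> & mapping) {
    counts.assign(n_as, 0);
    cum.assign(n_as + 1, 0);
    bool has_invalid = false;
    // pass 1: count how many (token row, expert slot) entries select each expert
    for (int64_t t = 0; t < n_rows; ++t) {
        for (int64_t s = 0; s < n_ids; ++s) {
            const int32_t e = ids[t*n_ids + s];
            if (e >= 0 && e < n_as) ++counts[e]; else has_invalid = true;
        }
    }
    // pass 2: exclusive prefix sum gives each expert's start offset in the mapping
    for (int32_t e = 0; e < n_as; ++e) cum[e + 1] = cum[e] + counts[e];
    // pass 3: scatter each entry into its expert's contiguous slice
    mapping.resize(cum[n_as]);
    std::vector<int32_t> offs(cum.begin(), cum.end() - 1);
    for (int64_t t = 0; t < n_rows; ++t) {
        for (int64_t s = 0; s < n_ids; ++s) {
            const int32_t e = ids[t*n_ids + s];
            if (e >= 0 && e < n_as) mapping[offs[e]++] = { (int32_t) s, (int32_t) t };
        }
    }
    return has_invalid;
}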

@@ -2321,54 +2321,127 @@ static __global__ void k_quick_add(uint32_t n_per_row, const float * src1, const
}
}
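// One thread per (token row, expert slot) entry of the ids tensor: count how many rows are
// routed to each expert with atomicAdd, and raise a flag if any id falls outside [0, n_as).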
static __global__ void k_moe_row_count(
const char * __restrict__ ids,
size_t nb0, size_t nb1,
int64_t n_ids, int64_t n_rows,
int32_t n_as,
int32_t * __restrict__ moe_counts,
int32_t * __restrict__ has_invalid_ids) {
const int64_t idx = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
const int64_t ne = n_ids * n_rows;
if (idx >= ne) {
return;
}
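    // decompose the flat index: iid1 = token row of ids, id = expert slot within that row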
const int64_t iid1 = idx / n_ids;
const int64_t id = idx - iid1*n_ids;
const int32_t row_id_i = *(const int32_t *) (ids + iid1*nb1 + id*nb0);
if ((uint32_t) row_id_i < (uint32_t) n_as) {
atomicAdd(moe_counts + row_id_i, 1);
} else {
atomicExch(has_invalid_ids, 1);
}
}
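// Single-threaded exclusive prefix sum over the per-expert counts; n_as (the number of experts)
// is small, so one thread suffices. cum_moe_counts[i] is the first mapping slot owned by
// expert i, and cum_moe_counts[n_as] is the total number of mapped rows.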
static __global__ void k_moe_row_exclusive_scan(
const int32_t * __restrict__ moe_counts,
int32_t * __restrict__ cum_moe_counts,
int32_t n_as) {
if (blockIdx.x != 0 || threadIdx.x != 0) {
return;
}
int32_t sum = 0;
cum_moe_counts[0] = 0;
for (int i = 0; i < n_as; ++i) {
sum += moe_counts[i];
cum_moe_counts[i + 1] = sum;
}
}
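// Second pass over ids: each entry reserves a slot in its expert's contiguous range via
// atomicAdd on the running per-expert offsets (seeded from the exclusive scan) and stores its
// {expert slot, token row} pair there. The order within an expert is nondeterministic.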
static __global__ void k_moe_row_scatter(
const char * __restrict__ ids,
size_t nb0, size_t nb1,
int64_t n_ids, int64_t n_rows,
int32_t n_as,
int32_t * __restrict__ row_offsets,
mmid_row_mapping * __restrict__ row_mapping) {
const int64_t idx = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
const int64_t ne = n_ids * n_rows;
if (idx >= ne) {
return;
}
const int64_t iid1 = idx / n_ids;
const int64_t id = idx - iid1*n_ids;
const int32_t row_id_i = *(const int32_t *) (ids + iid1*nb1 + id*nb0);
if ((uint32_t) row_id_i < (uint32_t) n_as) {
const int32_t dst_idx = atomicAdd(row_offsets + row_id_i, 1);
row_mapping[dst_idx] = { (int32_t) id, (int32_t) iid1 };
}
}
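// Builds the expert -> row mapping for mul_mat_id on the device: counts rows per expert, scans
// the counts, and scatters {expert slot, token row} pairs grouped by expert into dev_row_mapping.
// The per-expert counts and cumulative offsets are copied back into the host vectors for the
// caller. Returns true if any expert id was out of range.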
static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids,
const ggml_tensor * ids, std::vector<int>& moe_counts, std::vector<int>& cum_moe_counts,
ggml_cuda_pool_alloc<mmid_row_mapping>& dev_row_mapping) {
GGML_ASSERT(moe_counts.empty() && cum_moe_counts.empty());
GGML_ASSERT(ids->type == GGML_TYPE_I32);
GGML_ASSERT(n_as <= (int64_t) std::numeric_limits<int32_t>::max());
GGML_ASSERT(n_ids <= (int64_t) std::numeric_limits<int32_t>::max());
GGML_ASSERT(ids->ne[1] <= (int64_t) std::numeric_limits<int32_t>::max());
auto stream = ctx.stream();
const char * ids_dev = (const char *) ids->data;
const int64_t n_rows = ids->ne[1];
const int64_t n_entries = n_rows*n_ids;
moe_counts.resize(n_as, 0);
cum_moe_counts.resize(n_as + 1, 0);
if (n_entries == 0 || n_as == 0) {
return false;
}
// Build row mapping fully on-device to avoid per-call ids D2H/H2D round-trips.
dev_row_mapping.alloc(n_entries);
ggml_cuda_pool_alloc<int32_t> dev_moe_counts(ctx.pool(), n_as);
ggml_cuda_pool_alloc<int32_t> dev_cum_moe_counts(ctx.pool(), n_as + 1);
ggml_cuda_pool_alloc<int32_t> dev_row_offsets(ctx.pool(), n_as);
ggml_cuda_pool_alloc<int32_t> dev_has_invalid_ids(ctx.pool(), 1);
CUDA_CHECK(cudaMemsetAsync(dev_moe_counts.get(), 0, n_as*sizeof(int32_t), stream));
CUDA_CHECK(cudaMemsetAsync(dev_has_invalid_ids.get(), 0, sizeof(int32_t), stream));
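    // launch configuration: one thread per (token row, expert slot) entry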
constexpr int block_size = 256;
const dim3 grid_dims((n_entries + block_size - 1) / block_size, 1, 1);
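    // pass 1: per-expert row counts (and invalid-id detection)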
k_moe_row_count<<<grid_dims, block_size, 0, stream>>>(
ids_dev, ids->nb[0], ids->nb[1], n_ids, n_rows, n_as, dev_moe_counts.get(), dev_has_invalid_ids.get());
CUDA_CHECK(cudaGetLastError());
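    // pass 2: exclusive scan of the counts into per-expert start offsets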
k_moe_row_exclusive_scan<<<1, 1, 0, stream>>>(dev_moe_counts.get(), dev_cum_moe_counts.get(), n_as);
CUDA_CHECK(cudaGetLastError());
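    // pass 3: seed the per-expert write cursors with the scan results, then scatter each entry
    // into its expert's slice of the mapping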
CUDA_CHECK(cudaMemcpyAsync(dev_row_offsets.get(), dev_cum_moe_counts.get(), n_as*sizeof(int32_t), cudaMemcpyDeviceToDevice, stream));
k_moe_row_scatter<<<grid_dims, block_size, 0, stream>>>(
ids_dev, ids->nb[0], ids->nb[1], n_ids, n_rows, n_as, dev_row_offsets.get(), dev_row_mapping.get());
CUDA_CHECK(cudaGetLastError());
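    // copy the per-expert counts, offsets, and the invalid-id flag back to the host; the caller
    // uses the host-side vectors to size the per-expert matrix multiplications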
int32_t has_invalid_ids = 0;
CUDA_CHECK(cudaMemcpyAsync(moe_counts.data(), dev_moe_counts.get(), n_as*sizeof(int), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaMemcpyAsync(cum_moe_counts.data(), dev_cum_moe_counts.get(), (n_as + 1)*sizeof(int), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaMemcpyAsync(&has_invalid_ids, dev_has_invalid_ids.get(), sizeof(int32_t), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
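    // a nonzero flag means at least one routing id fell outside [0, n_as)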
return has_invalid_ids != 0;
}
static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {