mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-21 13:44:10 +00:00)
cuda: build MoE row mapping on device in mul_mat_id
@@ -2321,54 +2321,127 @@ static __global__ void k_quick_add(uint32_t n_per_row, const float * src1, const
 }
 }

+static __global__ void k_moe_row_count(
+        const char * __restrict__ ids,
+        size_t nb0, size_t nb1,
+        int64_t n_ids, int64_t n_rows,
+        int32_t n_as,
+        int32_t * __restrict__ moe_counts,
+        int32_t * __restrict__ has_invalid_ids) {
+    const int64_t idx = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
+    const int64_t ne = n_ids * n_rows;
+
+    if (idx >= ne) {
+        return;
+    }
+
+    const int64_t iid1 = idx / n_ids;
+    const int64_t id = idx - iid1*n_ids;
+    const int32_t row_id_i = *(const int32_t *) (ids + iid1*nb1 + id*nb0);
+
+    if ((uint32_t) row_id_i < (uint32_t) n_as) {
+        atomicAdd(moe_counts + row_id_i, 1);
+    } else {
+        atomicExch(has_invalid_ids, 1);
+    }
+}
+
+static __global__ void k_moe_row_exclusive_scan(
+        const int32_t * __restrict__ moe_counts,
+        int32_t * __restrict__ cum_moe_counts,
+        int32_t n_as) {
+    if (blockIdx.x != 0 || threadIdx.x != 0) {
+        return;
+    }
+
+    int32_t sum = 0;
+    cum_moe_counts[0] = 0;
+    for (int i = 0; i < n_as; ++i) {
+        sum += moe_counts[i];
+        cum_moe_counts[i + 1] = sum;
+    }
+}
+
+static __global__ void k_moe_row_scatter(
+        const char * __restrict__ ids,
+        size_t nb0, size_t nb1,
+        int64_t n_ids, int64_t n_rows,
+        int32_t n_as,
+        int32_t * __restrict__ row_offsets,
+        mmid_row_mapping * __restrict__ row_mapping) {
+    const int64_t idx = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
+    const int64_t ne = n_ids * n_rows;
+
+    if (idx >= ne) {
+        return;
+    }
+
+    const int64_t iid1 = idx / n_ids;
+    const int64_t id = idx - iid1*n_ids;
+    const int32_t row_id_i = *(const int32_t *) (ids + iid1*nb1 + id*nb0);
+
+    if ((uint32_t) row_id_i < (uint32_t) n_as) {
+        const int32_t dst_idx = atomicAdd(row_offsets + row_id_i, 1);
+        row_mapping[dst_idx] = { (int32_t) id, (int32_t) iid1 };
+    }
+}
+
 static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids,
         const ggml_tensor * ids, std::vector<int>& moe_counts, std::vector<int>& cum_moe_counts,
         ggml_cuda_pool_alloc<mmid_row_mapping>& dev_row_mapping) {

     GGML_ASSERT(moe_counts.empty() && cum_moe_counts.empty());
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+    GGML_ASSERT(n_as <= (int64_t) std::numeric_limits<int32_t>::max());
+    GGML_ASSERT(n_ids <= (int64_t) std::numeric_limits<int32_t>::max());
+    GGML_ASSERT(ids->ne[1] <= (int64_t) std::numeric_limits<int32_t>::max());

     auto stream = ctx.stream();

-    std::vector<char> ids_host(ggml_nbytes(ids));
     const char * ids_dev = (const char *) ids->data;
-    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    const int64_t n_rows = ids->ne[1];
+    const int64_t n_entries = n_rows*n_ids;
+
+    moe_counts.resize(n_as, 0);
+    cum_moe_counts.resize(n_as + 1, 0);
+
+    if (n_entries == 0 || n_as == 0) {
+        return false;
+    }
+
+    // Build row mapping fully on-device to avoid per-call ids D2H/H2D round-trips.
+    dev_row_mapping.alloc(n_entries);
+
+    ggml_cuda_pool_alloc<int32_t> dev_moe_counts(ctx.pool(), n_as);
+    ggml_cuda_pool_alloc<int32_t> dev_cum_moe_counts(ctx.pool(), n_as + 1);
+    ggml_cuda_pool_alloc<int32_t> dev_row_offsets(ctx.pool(), n_as);
+    ggml_cuda_pool_alloc<int32_t> dev_has_invalid_ids(ctx.pool(), 1);
+
+    CUDA_CHECK(cudaMemsetAsync(dev_moe_counts.get(), 0, n_as*sizeof(int32_t), stream));
+    CUDA_CHECK(cudaMemsetAsync(dev_has_invalid_ids.get(), 0, sizeof(int32_t), stream));
+
+    constexpr int block_size = 256;
+    const dim3 grid_dims((n_entries + block_size - 1) / block_size, 1, 1);
+    k_moe_row_count<<<grid_dims, block_size, 0, stream>>>(
+        ids_dev, ids->nb[0], ids->nb[1], n_ids, n_rows, n_as, dev_moe_counts.get(), dev_has_invalid_ids.get());
+    CUDA_CHECK(cudaGetLastError());
+
+    k_moe_row_exclusive_scan<<<1, 1, 0, stream>>>(dev_moe_counts.get(), dev_cum_moe_counts.get(), n_as);
+    CUDA_CHECK(cudaGetLastError());
+
+    CUDA_CHECK(cudaMemcpyAsync(dev_row_offsets.get(), dev_cum_moe_counts.get(), n_as*sizeof(int32_t), cudaMemcpyDeviceToDevice, stream));
+
+    k_moe_row_scatter<<<grid_dims, block_size, 0, stream>>>(
+        ids_dev, ids->nb[0], ids->nb[1], n_ids, n_rows, n_as, dev_row_offsets.get(), dev_row_mapping.get());
+    CUDA_CHECK(cudaGetLastError());
+
+    int32_t has_invalid_ids = 0;
+    CUDA_CHECK(cudaMemcpyAsync(moe_counts.data(), dev_moe_counts.get(), n_as*sizeof(int), cudaMemcpyDeviceToHost, stream));
+    CUDA_CHECK(cudaMemcpyAsync(cum_moe_counts.data(), dev_cum_moe_counts.get(), (n_as + 1)*sizeof(int), cudaMemcpyDeviceToHost, stream));
+    CUDA_CHECK(cudaMemcpyAsync(&has_invalid_ids, dev_has_invalid_ids.get(), sizeof(int32_t), cudaMemcpyDeviceToHost, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));

-    std::vector<mmid_row_mapping> rmapping(ids->ne[1]*n_ids);
-    moe_counts.resize(n_as, 0);
-    cum_moe_counts.resize(n_as + 1);
-
-    bool is_ser = false;
-    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-        for (int64_t id = 0; id < n_ids; id++) {
-            const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-            if (row_id_i >= 0 && row_id_i < n_as) ++moe_counts[row_id_i];
-            else is_ser = true;
-        }
-    }
-    cum_moe_counts[0] = 0;
-    for (int i = 0; i < (int)n_as; ++i) {
-        cum_moe_counts[i+1] = cum_moe_counts[i] + moe_counts[i];
-    }
-
-    dev_row_mapping.alloc(cum_moe_counts[n_as]);
-
-    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-        for (int64_t id = 0; id < n_ids; id++) {
-            const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-            if (row_id_i >= 0 && row_id_i < n_as) {
-                rmapping[cum_moe_counts[row_id_i]++] = {(int)id, (int)iid1};
-            }
-        }
-    }
-
-    for (int i = 0; i < (int)n_as; ++i) cum_moe_counts[i] -= moe_counts[i];
-
-    CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(),
-                cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream));
-    //CUDA_CHECK(cudaStreamSynchronize(stream));
-
-    return is_ser;
+    return has_invalid_ids != 0;
 }

 static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {

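Taken together, the three kernels implement a small counting sort over the expert ids: k_moe_row_count histograms the ids per expert with atomicAdd, k_moe_row_exclusive_scan turns the histogram into exclusive offsets, and k_moe_row_scatter writes each (id, iid1) pair into its expert's slice using an atomically advanced per-expert cursor. As a sanity reference only, here is a minimal host-side C++ sketch of the same count -> exclusive-scan -> scatter pattern; the names row_mapping_ref and build_row_mapping_ref, the dense ids layout, and the example in main() are illustrative and not part of the ik_llama.cpp API.

// Host-side reference of the pattern used by k_moe_row_count / k_moe_row_exclusive_scan / k_moe_row_scatter.
// All names here are illustrative; the real code works on a strided ggml tensor and GPU buffers.
#include <cstdint>
#include <cstdio>
#include <vector>

struct row_mapping_ref { int32_t id; int32_t iid1; };   // mirrors mmid_row_mapping: (expert slot, token row)

// ids is a dense [n_rows][n_ids] row-major matrix of expert indices.
// Returns true if any id falls outside [0, n_as), matching the has_invalid_ids flag.
static bool build_row_mapping_ref(const std::vector<int32_t> & ids, int64_t n_rows, int64_t n_ids, int32_t n_as,
                                  std::vector<int32_t> & counts, std::vector<int32_t> & offsets,
                                  std::vector<row_mapping_ref> & mapping) {
    counts.assign(n_as, 0);
    offsets.assign(n_as + 1, 0);
    mapping.assign(n_rows * n_ids, {0, 0});
    bool has_invalid = false;

    // Pass 1: per-expert histogram (k_moe_row_count does this with atomicAdd).
    for (int64_t iid1 = 0; iid1 < n_rows; ++iid1) {
        for (int64_t id = 0; id < n_ids; ++id) {
            const int32_t e = ids[iid1*n_ids + id];
            if ((uint32_t) e < (uint32_t) n_as) ++counts[e];
            else has_invalid = true;
        }
    }

    // Pass 2: exclusive prefix sum; offsets[e] is where expert e's rows start.
    for (int32_t e = 0; e < n_as; ++e) offsets[e + 1] = offsets[e] + counts[e];

    // Pass 3: scatter with a per-expert cursor (k_moe_row_scatter advances it with atomicAdd,
    // so the order within one expert is unspecified on the GPU; only the grouping matters).
    std::vector<int32_t> cursor(offsets.begin(), offsets.end() - 1);
    for (int64_t iid1 = 0; iid1 < n_rows; ++iid1) {
        for (int64_t id = 0; id < n_ids; ++id) {
            const int32_t e = ids[iid1*n_ids + id];
            if ((uint32_t) e < (uint32_t) n_as) mapping[cursor[e]++] = { (int32_t) id, (int32_t) iid1 };
        }
    }
    return has_invalid;
}

int main() {
    // 3 token rows, 2 selected experts per row, 4 experts total.
    const std::vector<int32_t> ids = { 0, 2,   2, 1,   3, 2 };
    std::vector<int32_t> counts, offsets;
    std::vector<row_mapping_ref> mapping;
    build_row_mapping_ref(ids, 3, 2, 4, counts, offsets, mapping);
    for (int32_t e = 0; e < 4; ++e) {
        std::printf("expert %d: rows", (int) e);
        for (int32_t i = offsets[e]; i < offsets[e + 1]; ++i) {
            std::printf(" (id=%d, iid1=%d)", (int) mapping[i].id, (int) mapping[i].iid1);
        }
        std::printf("\n");
    }
    return 0;
}

Two observations on the device version, hedged since they are inferences from the diff alone: the single-thread k_moe_row_exclusive_scan is presumably acceptable because n_as (the number of experts) is small, so a serial scan costs less than launching a parallel scan; and the atomic scatter makes row order within an expert nondeterministic, which should not affect correctness because each mapping entry carries its own (id, iid1) pair.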