Also this barely moves the needle

This commit is contained in:
Iwan Kawrakow
2025-08-28 18:20:18 +03:00
parent 37ef1d3719
commit 486f1adc1e
3 changed files with 66 additions and 22 deletions

View File

@@ -2758,11 +2758,6 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
return false;
}
std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data;
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
ggml_tensor src0_1_row = *src0_1;
ggml_tensor src0_2_row = *src0_2;
ggml_tensor src1_row = *src1;
@@ -2836,20 +2831,19 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
bool first = false; //true;
ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool());
std::vector<int> moe_counts, cum_moe_counts;
const int64_t ne_get_rows = ne12 * n_ids;
ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), ne_get_rows + (n_as + 2)/2);
bool is_ser = prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
if (is_ser) {
if (fuse_down) {
CUDA_CHECK(cudaMemsetAsync(next->data, 0, ggml_nbytes(next), stream));
} else {
CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
}
}
compute_row_ids2((const int32_t *)ids->data, dev_row_mapping.get(), (int32_t *)(dev_row_mapping.get() + ne_get_rows),
ne02, ne12, n_ids, ne11, nb11, nb12, ids->nb[1], stream);
std::vector<int> cum_moe_counts(n_as + 1);
CUDA_CHECK(cudaMemcpyAsync(cum_moe_counts.data(), dev_row_mapping.get() + ne_get_rows, (n_as+1)*sizeof(int),
cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
for (int64_t i02 = 0; i02 < n_as; i02++) {
int64_t num_src1_rows = moe_counts[i02];
int64_t num_src1_rows = cum_moe_counts[i02+1] - cum_moe_counts[i02];
if (num_src1_rows == 0) continue;
size_t mapping_offset = cum_moe_counts[i02];

View File

@@ -24,11 +24,13 @@ struct mmq_ids_helper_store {
};
static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
struct mmid_row_mapping { int32_t i1, i2; };
// Helper function for mul_mat_id, converts ids to a more convenient format.
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
// ids_dst describes the same mapping but for the dst tensor.
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
template <int n_expert_used_template>
template <int n_expert_used_template, int type = 0>
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mmq_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
@@ -103,9 +105,12 @@ static __global__ void mmq_ids_helper(
const mmq_ids_helper_store store_it = store[itc];
const int it = store_it.it();
const int iex_used = store_it.iex_used();
//ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
ids_src1[nex_prev + itc] = it;
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
if constexpr (type == 0) {
ids_src1[nex_prev + itc] = it;
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
} else {
((mmid_row_mapping *)ids_src1)[nex_prev + itc] = {iex_used, it};
}
}
if (threadIdx.x != 0) {
@@ -121,7 +126,7 @@ static __global__ void mmq_ids_helper(
expert_bounds[gridDim.x] = nex_prev + it_compact;
}
template <int n_expert_used_template>
template <int n_expert_used_template, int type = 0>
static void launch_mmq_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
@@ -136,7 +141,7 @@ static void launch_mmq_ids_helper(
const dim3 num_blocks(n_experts, 1, 1);
const dim3 block_size(warp_size, 1, 1);
const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
mmq_ids_helper<n_expert_used_template, type><<<num_blocks, block_size, nbytes_shared, stream>>>
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}
@@ -314,6 +319,46 @@ void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst,
CUDA_CHECK(cudaGetLastError());
}
void compute_row_ids2(const int32_t * ids, mmid_row_mapping * rmapping, int32_t * expert_bounds,
int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21,
cudaStream_t stream) {
const int si1 = nb21 / sizeof(int);
const int sis1 = nb12 / nb11;
switch (n_expert_used) {
case 2:
launch_mmq_ids_helper< 2, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds,
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 4:
launch_mmq_ids_helper< 4, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds,
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 6:
launch_mmq_ids_helper< 6, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds,
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 8:
launch_mmq_ids_helper< 8, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds,
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 16:
launch_mmq_ids_helper<16, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds,
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 32:
launch_mmq_ids_helper<32, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds,
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
default:
launch_mmq_ids_helper< 0, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds,
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
}
CUDA_CHECK(cudaGetLastError());
}
void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
const ggml_tensor * ids_tensor, ggml_tensor * dst, char * ids_data, char * src1_quantized_data, bool is_next) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);

View File

@@ -9,4 +9,9 @@ void ggml_cuda_mul_mat_q_id(
void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, cudaStream_t stream);
struct mmid_row_mapping;
void compute_row_ids2(const int32_t * ids, mmid_row_mapping * rmapping, int32_t * expert_bounds,
int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, cudaStream_t stream);
bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11);