Use int64_t for the expert stride to avoid overflow

This commit is contained in:
Feng Shijie
2025-08-21 06:58:55 +00:00
parent 9fbcc8f8a4
commit 85976b0b87
3 changed files with 19 additions and 18 deletions

View File

@@ -119,16 +119,16 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
? gather_token_id * strideA + k
: k * strideA + gather_token_id;
int b_index =
expert_id * N * K + ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col);
int b_index_up;
long b_index =
long(expert_id) * N * K +
((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>) ? col * strideB + k
: k * strideB + col);
long b_index_up;
if constexpr(MoeGemmKind == 1)
b_index_up =
expert_id * N * K + ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? (col + problem_N) * strideB + k
: k * strideB + col + problem_N);
b_index_up = long(expert_id) * N * K +
((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? (col + problem_N) * strideB + k
: k * strideB + col + problem_N);
AccDataType v_a;
AccDataType v_b;