mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
use int64_t as expert stride to avoid overflow
This commit is contained in:
@@ -119,16 +119,16 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
             ? gather_token_id * strideA + k
             : k * strideA + gather_token_id;
 
-            int b_index =
-                expert_id * N * K + ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
-                                         ? col * strideB + k
-                                         : k * strideB + col);
-            int b_index_up;
+            long b_index =
+                long(expert_id) * N * K +
+                ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>) ? col * strideB + k
+                                                                             : k * strideB + col);
+            long b_index_up;
             if constexpr(MoeGemmKind == 1)
-                b_index_up =
-                    expert_id * N * K + ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
-                                             ? (col + problem_N) * strideB + k
-                                             : k * strideB + col + problem_N);
+                b_index_up = long(expert_id) * N * K +
+                             ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
+                                  ? (col + problem_N) * strideB + k
+                                  : k * strideB + col + problem_N);
 
             AccDataType v_a;
             AccDataType v_b;
Reference in New Issue
Block a user