mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
use int64_t as expert stride to avoid overflow
This commit is contained in:
@@ -301,7 +301,7 @@ void shuffle_mxfp4_weight(const IterSrc src, IterDst dst, int experts_cnt, int N
|
||||
{
|
||||
int up_stride = N / 2 / NLane;
|
||||
|
||||
for(int eid = 0; eid < experts_cnt; ++eid)
|
||||
for(long eid = 0; eid < experts_cnt; ++eid)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
@@ -319,9 +319,9 @@ void shuffle_mxfp4_weight(const IterSrc src, IterDst dst, int experts_cnt, int N
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = eid * N * K_pk + n0_interleave * KPack * NLane * KLane * K0 +
|
||||
k0 * KPack * NLane * KLane + k1 * KPack * NLane + n1 * KPack +
|
||||
k2;
|
||||
long outputIndex = eid * N * K_pk + n0_interleave * KPack * NLane * KLane * K0 +
|
||||
k0 * KPack * NLane * KLane + k1 * KPack * NLane +
|
||||
n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[eid * N * K_pk + n * K_pk + k];
|
||||
}
|
||||
@@ -330,7 +330,7 @@ void shuffle_mxfp4_weight(const IterSrc src, IterDst dst, int experts_cnt, int N
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int eid = 0; eid < experts_cnt; ++eid)
|
||||
for(long eid = 0; eid < experts_cnt; ++eid)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
@@ -344,9 +344,9 @@ void shuffle_mxfp4_weight(const IterSrc src, IterDst dst, int experts_cnt, int N
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = eid * N * K_pk + n0 * KPack * NLane * KLane * K0 +
|
||||
k0 * KPack * NLane * KLane + k1 * KPack * NLane + n1 * KPack +
|
||||
k2;
|
||||
long outputIndex = eid * N * K_pk + n0 * KPack * NLane * KLane * K0 +
|
||||
k0 * KPack * NLane * KLane + k1 * KPack * NLane +
|
||||
n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[eid * N * K_pk + n * K_pk + k];
|
||||
}
|
||||
|
||||
@@ -119,16 +119,16 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
? gather_token_id * strideA + k
|
||||
: k * strideA + gather_token_id;
|
||||
|
||||
int b_index =
|
||||
expert_id * N * K + ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? col * strideB + k
|
||||
: k * strideB + col);
|
||||
int b_index_up;
|
||||
long b_index =
|
||||
long(expert_id) * N * K +
|
||||
((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>) ? col * strideB + k
|
||||
: k * strideB + col);
|
||||
long b_index_up;
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
b_index_up =
|
||||
expert_id * N * K + ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? (col + problem_N) * strideB + k
|
||||
: k * strideB + col + problem_N);
|
||||
b_index_up = long(expert_id) * N * K +
|
||||
((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? (col + problem_N) * strideB + k
|
||||
: k * strideB + col + problem_N);
|
||||
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
@@ -644,7 +644,8 @@ struct MoeFlatmmKernel
|
||||
});
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
const index_t expert_stride = __builtin_amdgcn_readfirstlane(kargs.N * kargs.K);
|
||||
const long_index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(long_index_t(kargs.N) * kargs.K);
|
||||
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
|
||||
Reference in New Issue
Block a user