mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
fix: address GPU memory segmentation fault caused by int32 overflow
This commit is contained in:
@@ -240,7 +240,8 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
arg.NumTokens * arg.TopK * arg_.N *
|
||||
sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
@@ -256,10 +257,11 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.NumTokens * arg.TopK * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
|
||||
@@ -1245,9 +1245,10 @@ struct GridwiseMoeGemm
|
||||
}
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
const IndexType expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
|
||||
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const IndexType expert_offset =
|
||||
static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize;
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
|
||||
@@ -1960,9 +1961,10 @@ struct GridwiseMoeGemm
|
||||
}
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
const IndexType expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
|
||||
const long_index_t expert_stride = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(problem.N) * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const IndexType expert_offset =
|
||||
static_cast<long_index_t>(expert_id) * expert_stride / BPackedSize;
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
|
||||
|
||||
Reference in New Issue
Block a user