Grouped conv fwd v3 fix for SplitN an G > 1 (#2038)

* Grouped conv fwd v3 fix for SplitN an G > 1

* Remove int8 large test

* Retore int8 test
This commit is contained in:
Bartłomiej Kocot
2025-04-01 22:19:35 +02:00
committed by GitHub
parent df32020f93
commit ec742908bd
2 changed files with 14 additions and 18 deletions

View File

@@ -79,15 +79,12 @@ __global__ void
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups,
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n,
[[maybe_unused]] const index_t groups_count)
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
const index_t& num_blocks_per_n = groups_count;
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
@@ -141,15 +138,12 @@ __global__ void
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups,
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n,
[[maybe_unused]] const index_t groups_count)
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
const index_t& num_blocks_per_n = groups_count;
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
@@ -766,7 +760,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
std::tie(gdx, gdy, gdz) =
GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/);
gdy *= arg.num_group_ * num_workgroups_per_Conv_N;
gdy = arg.num_group_;
gdz = num_workgroups_per_Conv_N;
index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
@@ -820,8 +815,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
arg.b_grid_desc_bk0_n_bk1_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_,
arg.num_group_);
arg.compute_ptr_offset_of_n_);
}
else
{
@@ -836,8 +830,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
arg.b_grid_desc_bk0_n_bk1_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_,
arg.num_group_);
arg.compute_ptr_offset_of_n_);
}
};