mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Grouped conv fwd v3 fix for SplitN an G > 1 (#2038)
* Grouped conv fwd v3 fix for SplitN an G > 1 * Remove int8 large test * Retore int8 test
This commit is contained in:
@@ -79,15 +79,12 @@ __global__ void
|
||||
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups,
|
||||
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n,
|
||||
[[maybe_unused]] const index_t groups_count)
|
||||
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
|
||||
const index_t& num_blocks_per_n = groups_count;
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
|
||||
@@ -141,15 +138,12 @@ __global__ void
|
||||
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups,
|
||||
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n,
|
||||
[[maybe_unused]] const index_t groups_count)
|
||||
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
|
||||
const index_t& num_blocks_per_n = groups_count;
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
|
||||
@@ -766,7 +760,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
std::tie(gdx, gdy, gdz) =
|
||||
GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/);
|
||||
|
||||
gdy *= arg.num_group_ * num_workgroups_per_Conv_N;
|
||||
gdy = arg.num_group_;
|
||||
gdz = num_workgroups_per_Conv_N;
|
||||
|
||||
index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
@@ -820,8 +815,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.num_group_);
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -836,8 +830,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.num_group_);
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user