mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
[CK][CK Tile] Fix dram step for KM/KN layouts in V1 pipeline (#5470)
## Motivation

Fix the V1 pipeline for KM/KN layouts by passing the correct step for the DRAM tile window.

## Technical Details

- Fix the DRAM step for KM/KN layouts in the V1 pipeline
- Disable instances which use more threads than the warp size in the contiguous dimension (not supported in CK Tile yet)
- Use the 1x1 specialization for explicit GEMM
- Use two-stage for vector size == 1 and sizeof(datatype) == 2
- Remove the unneeded check, since GetVectorSizeA/B already check whether the vector size is fixed

## Test Plan

test_grouped_convnd_bwd_weight_tile

## Test Result

Passed locally.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

AICK-966
This commit is contained in:
@@ -449,6 +449,14 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
@@ -471,10 +479,10 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
// move to 1
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
@@ -524,8 +532,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
is_b_load_tr_v);
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
|
||||
@@ -711,11 +711,10 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize =
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>();
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using ALayout = remove_cvref_t<
|
||||
@@ -756,9 +755,7 @@ struct UniversalGemmBasePolicy
|
||||
// since the assumption is that A type is going to be the B LDS type
|
||||
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
|
||||
constexpr index_t VecLoadSize =
|
||||
IsBCastPolicyBeforeLDSWrite
|
||||
? (Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>())
|
||||
: (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>());
|
||||
IsBCastPolicyBeforeLDSWrite ? GetVectorSizeA<Problem>() : GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
using BLayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
|
||||
|
||||
Reference in New Issue
Block a user