mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
[CK][CK Tile] Fix dram step for KM/KN layouts in V1 pipeline (#5470)
## Motivation Fix v1 pipeline for KM/KN layouts by passing correct step for dram tile window. ## Technical Details - Fix dram step for KM/KN layouts in V1 pipeline - Disable instances which use more threads than warp size in continous dim (not supported in ck tile yet) - Use 1x1 specialization for explicit gemm - Use two stage for vectorsize =1 and sizeof(datatype) ==2 - remove not needed check sinze GetVectorSizeA/B check if vector size is fixed ## Test Plan test_grouped_convnd_bwd_weight_tile ## Test Result passed locally ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. AICK-966
This commit is contained in:
@@ -317,7 +317,7 @@ def parse_bwd_weight_instances(instances, problem_name):
|
||||
gemm_params = device_op_name = instance.split("<")[2].split(">")[1].split(",")
|
||||
args = [param.split(":")[1].strip() for param in gemm_params]
|
||||
|
||||
spec = "Default"
|
||||
spec = "Filter1x1Stride1Pad0"
|
||||
block_size = int(args[0])
|
||||
|
||||
mnk_per_block = args[1].split("x")
|
||||
@@ -450,6 +450,13 @@ def parse_bwd_weight_instances(instances, problem_name):
|
||||
if pipeline_version == "V6":
|
||||
print(f"Skipping instance {instance_id} with V6 since it's not supported yet.")
|
||||
continue
|
||||
if m_per_block > (warp_size * a_scalar_per_vector) or n_per_block > (warp_size * b_scalar_per_vector):
|
||||
print(f"Skipping instance {instance_id} with multiple warps per continous tile dim since it's not supported yet.")
|
||||
continue
|
||||
|
||||
if is_explicit_gemm:
|
||||
if dtype != "float" and c_scalar_per_vector % 2 != 0:
|
||||
is_two_stage_instance = True
|
||||
|
||||
conv = ConvInstanceTemplateParams(
|
||||
spec,
|
||||
|
||||
@@ -449,6 +449,14 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
@@ -471,10 +479,10 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
// move to 1
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
@@ -524,8 +532,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
is_b_load_tr_v);
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
|
||||
@@ -711,11 +711,10 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize =
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>();
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using ALayout = remove_cvref_t<
|
||||
@@ -756,9 +755,7 @@ struct UniversalGemmBasePolicy
|
||||
// since the assumption is that A type is going to be the B LDS type
|
||||
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
|
||||
constexpr index_t VecLoadSize =
|
||||
IsBCastPolicyBeforeLDSWrite
|
||||
? (Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>())
|
||||
: (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>());
|
||||
IsBCastPolicyBeforeLDSWrite ? GetVectorSizeA<Problem>() : GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
using BLayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
|
||||
|
||||
Reference in New Issue
Block a user