[CK][CK Tile] Fix dram step for KM/KN layouts in V1 pipeline (#5470)

## Motivation

Fix v1 pipeline for KM/KN layouts by passing correct step for dram tile
window.

## Technical Details

- Fix dram step for KM/KN layouts in V1 pipeline
- Disable instances which use more threads than warp size in continous
dim (not supported in ck tile yet)
- Use 1x1 specialization for explicit gemm
- Use two stage for vectorsize =1 and sizeof(datatype) ==2
- remove not needed check sinze GetVectorSizeA/B check if vector size is
fixed

## Test Plan

test_grouped_convnd_bwd_weight_tile

## Test Result

passed locally

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

AICK-966
This commit is contained in:
Bartłomiej Kocot
2026-03-19 12:59:44 +01:00
committed by GitHub
parent 7a8410498d
commit b90e64e600
3 changed files with 25 additions and 13 deletions

View File

@@ -317,7 +317,7 @@ def parse_bwd_weight_instances(instances, problem_name):
gemm_params = device_op_name = instance.split("<")[2].split(">")[1].split(",")
args = [param.split(":")[1].strip() for param in gemm_params]
spec = "Default"
spec = "Filter1x1Stride1Pad0"
block_size = int(args[0])
mnk_per_block = args[1].split("x")
@@ -450,6 +450,13 @@ def parse_bwd_weight_instances(instances, problem_name):
if pipeline_version == "V6":
print(f"Skipping instance {instance_id} with V6 since it's not supported yet.")
continue
if m_per_block > (warp_size * a_scalar_per_vector) or n_per_block > (warp_size * b_scalar_per_vector):
print(f"Skipping instance {instance_id} with multiple warps per continous tile dim since it's not supported yet.")
continue
if is_explicit_gemm:
if dtype != "float" and c_scalar_per_vector % 2 != 0:
is_two_stage_instance = True
conv = ConvInstanceTemplateParams(
spec,

View File

@@ -449,6 +449,14 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
// Block GEMM
auto block_gemm = BlockGemm();
@@ -471,10 +479,10 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
// move to 1
// Move each A — the enhanced function move_tile_window is executed, which takes a
// tuple as input.
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
// Move each B — the enhanced function move_tile_window is executed, which takes a
// tuple as input.
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
@@ -524,8 +532,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
is_b_load_tr_v);
// move to i + 2
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
// LDS write i + 1
if constexpr(is_a_col_major && !is_a_load_tr_v())

View File

@@ -711,11 +711,10 @@ struct UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize =
Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>();
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
using ALayout = remove_cvref_t<
@@ -756,9 +755,7 @@ struct UniversalGemmBasePolicy
// since the assumption is that A type is going to be the B LDS type
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
constexpr index_t VecLoadSize =
IsBCastPolicyBeforeLDSWrite
? (Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>())
: (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>());
IsBCastPolicyBeforeLDSWrite ? GetVectorSizeA<Problem>() : GetVectorSizeB<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
using BLayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;