[CK_TILE] Refine Generic2dBlockShape to fix ck_tile example 2,10,11,14 on rdna3 and 4 (#2795)

BlockWarps, WarpTile in Generic2dBlockShape are wave size dependent, it causes mangled name mismatch between host and device side.

Solution: Replace them with ThreadPerBlock and move BlockWarps, WarpTile calculation into Generic2dBlockShape
This commit is contained in:
linqunAMD
2025-09-10 08:29:20 +08:00
committed by GitHub
parent df4ee556d6
commit c254f3d7b4
14 changed files with 103 additions and 453 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -35,43 +35,69 @@ namespace ck_tile {
+-----------+-----------+-----------+-----------+-----------+
// clang-format on
*/
template <typename BlockTile_, // block size, seq<M, N>
typename WarpPerBlock_, // num warps along seq<M, N>
typename WarpTile_, // warp size, seq<M, N>
typename Vector_> // contiguous pixels(vector size) along seq<M, N>)>
template <typename BlockTile_, // block size, seq<M, N>
typename ThreadPerBlock_, // num threads along seq<M, N>
typename Vector_> // contiguous pixels(vector size) along seq<M, N>)>
struct Generic2dBlockShape
{
// block size
static constexpr index_t Block_M = BlockTile_::at(number<0>{});
static constexpr index_t Block_N = BlockTile_::at(number<1>{});
// num warps along seq<M, N>, within each block
static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});
// warp size
static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
static constexpr index_t Warp_N = WarpTile_::at(number<1>{});
static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
// repeat of each thread along seq<M, N>
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
static constexpr index_t Block_M = BlockTile_::at(number<0>{});
static constexpr index_t Block_N = BlockTile_::at(number<1>{});
static constexpr index_t ThreadPerBlock_M = ThreadPerBlock_::at(number<0>{});
static constexpr index_t ThreadPerBlock_N = ThreadPerBlock_::at(number<1>{});
static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N;
// vector size along seq<M, N>
static constexpr index_t Vector_M = Vector_::at(number<0>{});
static constexpr index_t Vector_N = Vector_::at(number<1>{});
static constexpr bool is_warp_per_row = ThreadPerBlock_N <= get_warp_size();
static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % get_warp_size() == 0);
static constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / get_warp_size();
// num warps along seq<M, N>, within each block
static constexpr index_t WarpPerBlock_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(get_warp_size() % ThreadPerBlock_N == 0);
return total_warps * (get_warp_size() / ThreadPerBlock_N);
}
else
{
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N / get_warp_size());
}
}();
// num of warps along n
static constexpr index_t WarpPerBlock_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(get_warp_size() % ThreadPerBlock_N == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N % get_warp_size() == 0);
return ThreadPerBlock_N / get_warp_size();
}
}();
// warp size
static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M;
static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N;
static_assert(Warp_M % Vector_M == 0);
static_assert(Warp_N % Vector_N == 0);
// num of threads along seq<M, N>, within each warp
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
static constexpr index_t ThreadPerBlock_M = Block_M / Repeat_M / Vector_M;
static constexpr index_t ThreadPerBlock_N = Block_N / Repeat_N / Vector_N;
static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N;
// repeat of each thread along seq<M, N>
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
// num of threads along seq<M, N>, within each warp
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
};
} // namespace ck_tile