[CK_TILE] fix example reduce, permute and elementwise on gfx11 & gfx12 (#2810)

1. Refine Reduce2dShape to support both wave32 and wave64
2. Fix example reduce, permute and elementwise on gfx11 and gfx12

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>

[ROCm/composable_kernel commit: 0b9a638f26]
This commit is contained in:
linqunAMD
2025-09-11 12:41:20 +08:00
committed by GitHub
parent 2eb6cbb6a8
commit eaf1fa7edb
11 changed files with 38 additions and 22 deletions

View File

@@ -343,7 +343,6 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config
using WarpTile = ck_tile::sequence<32, 128>;
using ThreadTile = ck_tile::sequence<8, 8>;
constexpr ck_tile::index_t kBlockSize = 256;
constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::index_t kGridSize = (output_size + BlockTile::at(ck_tile::number<0>{}) - 1) /
@@ -352,7 +351,8 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
using Problem =
ck_tile::Reduce2dProblem<CDataType, ComputeDataType, CDataType, Shape, ReduceOp>;
using Kernel = ck_tile::Reduce<Problem>;
using Kernel = ck_tile::Reduce<Problem>;
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(reduce_dim_size, workspace_strides))
{
@@ -992,7 +992,11 @@ int main(int argc, char* argv[])
try
{
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigComputeV3_WMMA>(arg_parser);
#else
return !run_gemm_example<GemmConfigComputeV3>(arg_parser);
#endif
}
catch(const std::runtime_error& e)
{

View File

@@ -88,7 +88,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
// using WarpTile = ck_tile::sequence<1, 512>;
// using Vector = ck_tile::sequence<1, 8>;
constexpr ck_tile::index_t kBlockSize = 256;
constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::index_t kept_dim_len_prod = N * C;
ck_tile::index_t kGridSize = (kept_dim_len_prod + BlockTile::at(ck_tile::number<0>{}) - 1) /
@@ -99,8 +98,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Porblem =
ck_tile::Reduce2dProblem<XDataType, ComputeDataType, YDataType, Shape, ReduceOp>;
using Kernel = ck_tile::Reduce<Porblem>;
using Kernel = ck_tile::Reduce<Porblem>;
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
// Create input tensor shape and strides
auto input_shape =
ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);

View File

@@ -88,10 +88,9 @@ struct matrix_core_swizzle_kernel
using karg = matrix_core_swizzle_host_args;
using harg = matrix_core_swizzle_host_args;
static constexpr int BLOCK_SIZE = BLOCK_SIZE_;
static constexpr int WavesPerBlock_N = 4;
static constexpr int WavesPerBlock_K = 1;
static_assert(WavesPerBlock_N * WavesPerBlock_K * 64 == BLOCK_SIZE);
static constexpr int BLOCK_SIZE = BLOCK_SIZE_;
static constexpr int WavesPerBlock_N = BLOCK_SIZE / ck_tile::get_warp_size();
static constexpr int WavesPerBlock_K = 1;
static constexpr int NPerBlock = NPerBlock_;
static constexpr int KPerBlock = KPerBlock_;
static constexpr matrix_core_permute_style pstyle = pstyle_;

View File

@@ -3,6 +3,7 @@
#include "permute.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <array>
#include <cstring>
@@ -128,6 +129,7 @@ auto create_args(int argc, char* argv[])
"non-deterministic seed")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "permute.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
@@ -257,6 +259,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
return permute(t, a, stream_config);
};
#if !CK_TILE_USE_WMMA
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") ||
@@ -345,6 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
else
#endif
#endif
{
ave_time = run_permute();

View File

@@ -137,8 +137,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// This is often a multiple of the wavefront size, 64 on CDNA.
// Here, it's explicitly set to 512. This should be consistent with Shape::kBlockSize.
// Shape::kBlockSize would be BlockWarps * warpSize (e.g., 8 * 64 = 512).
constexpr ck_tile::index_t kBlockSize =
ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
// kBlockPerCu: Hint for how many workgroups can be scheduled per Compute Unit (CU).
// This can influence occupancy and performance.

View File

@@ -84,8 +84,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
for(auto d : problem_shape)
total_elements *= d;
constexpr ck_tile::index_t kBlockSize =
ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 2;

View File

@@ -89,8 +89,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t total_elements = M * N;
constexpr ck_tile::index_t kBlockSize =
ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;

View File

@@ -78,8 +78,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
for(auto d : shape)
total_elements *= d;
constexpr ck_tile::index_t kBlockSize =
ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});

View File

@@ -21,7 +21,10 @@ struct ElementWiseKernel
using ElementWiseOperation = ck_tile::remove_cvref_t<typename Problem::ElementWiseOperation>;
static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize;
CK_TILE_HOST static constexpr auto BlockSize()
{
return is_wave32() ? kBlockSize / 2 : kBlockSize;
}
template <typename... XDataType, typename Dims>
CK_TILE_DEVICE void operator()(const Dims lens,
const Dims input_strides,

View File

@@ -26,6 +26,10 @@ struct Reduce
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
CK_TILE_HOST static constexpr auto BlockSize()
{
return is_wave32() ? kBlockSize / 2 : kBlockSize;
}
private:
// Helper function to calculate optimal vector size for input tensor

View File

@@ -25,11 +25,18 @@ struct Reduce2dShape
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N;
static constexpr index_t RepeatInWarp =
Warp_M * Warp_N / ThreadTile_M / ThreadTile_N / ck_tile::get_warp_size();
static constexpr index_t RepeatInWarp_M =
(Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? RepeatInWarp : 1;
static constexpr index_t RepeatInWarp_N =
(Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? 1 : RepeatInWarp;
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M / RepeatInWarp_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N / RepeatInWarp_N;
static constexpr index_t Repeat_M = Block_M * RepeatInWarp_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N);
static constexpr index_t BlockSize =
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});