mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
[CK_TILE] fix example reduces, permute and elementwise on gfx11 & gfx12 (#2810)
1. Refine Reduce2dShape to support both wave32 and wave64
2. Fix example reduce, permute and elementwise on gfx11 and gfx12
---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
[ROCm/composable_kernel commit: 0b9a638f26]
This commit is contained in:
@@ -88,10 +88,9 @@ struct matrix_core_swizzle_kernel
|
||||
using karg = matrix_core_swizzle_host_args;
|
||||
using harg = matrix_core_swizzle_host_args;
|
||||
|
||||
static constexpr int BLOCK_SIZE = BLOCK_SIZE_;
|
||||
static constexpr int WavesPerBlock_N = 4;
|
||||
static constexpr int WavesPerBlock_K = 1;
|
||||
static_assert(WavesPerBlock_N * WavesPerBlock_K * 64 == BLOCK_SIZE);
|
||||
static constexpr int BLOCK_SIZE = BLOCK_SIZE_;
|
||||
static constexpr int WavesPerBlock_N = BLOCK_SIZE / ck_tile::get_warp_size();
|
||||
static constexpr int WavesPerBlock_K = 1;
|
||||
static constexpr int NPerBlock = NPerBlock_;
|
||||
static constexpr int KPerBlock = KPerBlock_;
|
||||
static constexpr matrix_core_permute_style pstyle = pstyle_;
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "permute.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
@@ -128,6 +129,7 @@ auto create_args(int argc, char* argv[])
|
||||
"non-deterministic seed")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "permute.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
@@ -257,6 +259,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
return permute(t, a, stream_config);
|
||||
};
|
||||
#if !CK_TILE_USE_WMMA
|
||||
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
|
||||
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
|
||||
if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") ||
|
||||
@@ -345,6 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
#endif
|
||||
{
|
||||
ave_time = run_permute();
|
||||
|
||||
Reference in New Issue
Block a user