mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
Add benchmark example
This commit is contained in:
@@ -677,21 +677,15 @@ class TestCkTileMHC : public ::testing::Test
|
||||
d_phi_mem.ToDevice(h_phi.data());
|
||||
d_output_mem.ToDevice(h_output.data());
|
||||
|
||||
// Define block shape - must match BlockGemmShape thread count (2 warps × 64 = 128 threads)
|
||||
// Define block shape - 128 threads (2 warps) to match BlockGemmShape configuration
|
||||
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 128>,
|
||||
ck_tile::sequence<1, 128>,
|
||||
ck_tile::sequence<1, 1>>;
|
||||
|
||||
using Problem = ck_tile::MHCProblem<XDataType, ComputeDataType, YDataType, BlockShape>;
|
||||
|
||||
// V3 kernel with 2D tiling
|
||||
constexpr ck_tile::index_t kMTile = 64; // Batch tile
|
||||
constexpr ck_tile::index_t kNTile = 32; // Output tile (exactly covers 24 outputs for n=4)
|
||||
constexpr ck_tile::index_t kKTile =
|
||||
8; // K tile for C dimension (must match BlockGemmShape::kK)
|
||||
|
||||
using KernelV3 = ck_tile::
|
||||
MHCKernelV3<Problem, ck_tile::MHCDefaultPolicy, kMTile, kNTile, kKTile, ActivationFunc>;
|
||||
// V3 kernel - tile sizes automatically derived from Problem::BlockGemmShape
|
||||
using KernelV3 = ck_tile::MHCKernelV3<Problem, ck_tile::MHCDefaultPolicy, ActivationFunc>;
|
||||
|
||||
const ck_tile::index_t kBlockSize = KernelV3::BlockSize();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user