Add benchmark example

This commit is contained in:
Damien Lejeune
2026-02-06 14:55:13 +00:00
parent 804a9d488c
commit ec1e8ec58e
6 changed files with 231 additions and 28 deletions

View File

@@ -677,21 +677,15 @@ class TestCkTileMHC : public ::testing::Test
d_phi_mem.ToDevice(h_phi.data());
d_output_mem.ToDevice(h_output.data());
// Define block shape - must match BlockGemmShape thread count (2 warps × 64 = 128 threads)
// Define block shape - 128 threads (2 warps) to match BlockGemmShape configuration
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 128>,
ck_tile::sequence<1, 128>,
ck_tile::sequence<1, 1>>;
using Problem = ck_tile::MHCProblem<XDataType, ComputeDataType, YDataType, BlockShape>;
// V3 kernel with 2D tiling
constexpr ck_tile::index_t kMTile = 64; // Batch tile
constexpr ck_tile::index_t kNTile = 32; // Output tile (a single 32-wide tile covers all 24 outputs for n=4)
constexpr ck_tile::index_t kKTile =
8; // K tile for C dimension (must match BlockGemmShape::kK)
using KernelV3 = ck_tile::
MHCKernelV3<Problem, ck_tile::MHCDefaultPolicy, kMTile, kNTile, kKTile, ActivationFunc>;
// V3 kernel - tile sizes automatically derived from Problem::BlockGemmShape
using KernelV3 = ck_tile::MHCKernelV3<Problem, ck_tile::MHCDefaultPolicy, ActivationFunc>;
const ck_tile::index_t kBlockSize = KernelV3::BlockSize();