MHC V3 with gemm pipeline

This commit is contained in:
Damien Lejeune
2026-02-05 17:11:09 +00:00
parent 43a5678fdf
commit 053aed9402
5 changed files with 83 additions and 78 deletions

View File

@@ -48,9 +48,9 @@ int main()
d_phi_mem.ToDevice(h_phi.data());
d_output_mem.ToDevice(h_output.data());
// Define block shape
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 256>,
ck_tile::sequence<1, 256>,
// Define block shape - must match BlockGemmShape thread count (2 warps × 64 = 128 threads)
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 128>,
ck_tile::sequence<1, 128>,
ck_tile::sequence<1, 1>>;
using Problem = ck_tile::MHCProblem<float, float, float, BlockShape>;

View File

@@ -49,9 +49,9 @@ int main()
d_phi_mem.ToDevice(h_phi.data());
d_output_mem.ToDevice(h_output.data());
// Define block shape
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 256>,
ck_tile::sequence<1, 256>,
// Define block shape - must match BlockGemmShape thread count (2 warps × 64 = 128 threads)
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 128>,
ck_tile::sequence<1, 128>,
ck_tile::sequence<1, 1>>;
using Problem = ck_tile::MHCProblem<float, float, float, BlockShape>;

View File

@@ -49,9 +49,9 @@ int main()
d_phi_mem.ToDevice(h_phi.data());
d_output_mem.ToDevice(h_output.data());
// Define block shape
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 256>,
ck_tile::sequence<1, 256>,
// Define block shape - must match BlockGemmShape thread count (2 warps × 64 = 128 threads)
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 128>,
ck_tile::sequence<1, 128>,
ck_tile::sequence<1, 1>>;
using Problem = ck_tile::MHCProblem<float, float, float, BlockShape>;