use persistent

This commit is contained in:
Sami Remes
2026-02-06 18:21:34 +00:00
parent dc4366a876
commit 1622674c9e
2 changed files with 5 additions and 5 deletions

View File

@@ -81,7 +81,7 @@ struct MXfp4_GemmConfig16 : MxGemmConfig
// GEMM config with 16x16 warp tile
// Derives from MxGemmConfig and declares the M/N/K tile-size constants listed below.
struct MXfp8_GemmConfig16 : MxGemmConfig
{
// NOTE(review): M_Tile appears twice here (32, then 64). This text looks like a rendered
// diff hunk whose +/- markers were stripped, so presumably 32 is the removed value and
// 64 is the added one -- confirm against the real file; only one declaration can exist.
static constexpr ck_tile::index_t M_Tile = 32;
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
};

View File

@@ -50,8 +50,8 @@ int run_mx_gemm_with_layouts(int argc,
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
// Scale tensors - follow parent matrix layouts for optimal memory access
// A scales: [M, K/32] with A's layout → coalescing follows A's pattern
// B scales: [K/32, N] with B's layout → coalescing follows B's pattern
// A scales: [M, K/32] with A's layout
// B scales: [K/32, N] with B's layout
using ScaleType = ck_tile::e8m0_t;
ck_tile::index_t scale_k_size = K / 32;
@@ -189,7 +189,7 @@ int run_mx_gemm_example(int argc, char* argv[])
ck_tile::pk_fp4_t,
float,
MXfp4_GemmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
true>(argc, argv, Row{}, Col{}, Row{});
}
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
{
@@ -197,7 +197,7 @@ int run_mx_gemm_example(int argc, char* argv[])
ck_tile::fp8_t,
float,
MXfp8_GemmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
true>(argc, argv, Row{}, Col{}, Row{});
}
else
{