enable fp8 mx gemm too

This commit is contained in:
Sami Remes
2026-01-30 12:43:49 -05:00
parent 771c46aa8b
commit b8cdea5979
2 changed files with 13 additions and 13 deletions

View File

@@ -81,7 +81,7 @@ struct MXfp4_GemmConfig16 : MxGemmConfig
// GEMM config with 16x16 warp tile
struct MXfp8_GemmConfig16 : MxGemmConfig
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Tile = 32;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 512;
};

View File

@@ -60,8 +60,8 @@ int run_mx_gemm_with_layouts(int argc,
case 0:
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_host);
ck_tile::FillUniformDistribution<ScaleType>{-1.f, 1.f}(scale_a_host);
ck_tile::FillUniformDistribution<ScaleType>{-1.f, 1.f}(scale_b_host);
ck_tile::FillUniformDistribution<ScaleType>{1.f, 10.f}(scale_a_host);
ck_tile::FillUniformDistribution<ScaleType>{1.f, 10.f}(scale_b_host);
break;
case 1:
ck_tile::FillConstant<ADataType>{ADataType(1.f)}(a_host);
@@ -160,14 +160,14 @@ int run_mx_gemm_example(int argc, char* argv[])
MXfp4_GemmConfig16,
true>(argc, argv, Row{}, Col{}, Row{});
}
// else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
// {
// return run_mx_gemm_with_layouts<ck_tile::fp8_t,
// ck_tile::fp8_t,
// float,
// MXfp8_GemmConfig16,
// false>(argc, argv, Row{}, Col{}, Row{});
// }
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
{
return run_mx_gemm_with_layouts<ck_tile::fp8_t,
ck_tile::fp8_t,
float,
MXfp8_GemmConfig16,
true>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Only fp4 is supported currently!");