fixed vector load size for fp4

This commit is contained in:
Sami Remes
2026-01-16 12:04:34 -05:00
parent 16ca5cb532
commit f09e10936d
7 changed files with 66 additions and 50 deletions

View File

@@ -60,40 +60,40 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
scale_m,
scale_n);
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::sequence<GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile>>;
// using GemmShape = ck_tile::TileGemmShape<
// ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
// ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
// ck_tile::sequence<GemmConfig::M_Warp_Tile,
// GemmConfig::N_Warp_Tile,
// GemmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
// using TilePartitioner =
// ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
// GemmConfig::TileParitionerGroupNum,
// GemmConfig::TileParitionerM01>;
using MXGemmTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
UsePersistentKernel,
GemmConfig::NumWaveGroups,
false>;
// using MXGemmTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
// GemmConfig::kPadN,
// GemmConfig::kPadK,
// GemmConfig::DoubleSmemBuffer,
// ALayout,
// BLayout,
// CLayout,
// GemmConfig::TransposeC,
// GemmConfig::UseStructuredSparsity,
// UsePersistentKernel,
// GemmConfig::NumWaveGroups,
// false>;
using MXPipelineProblem = MXGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
MXGemmTraits,
GemmConfig::Scheduler>;
// using MXPipelineProblem = MXGemmPipelineProblem<ADataType,
// BDataType,
// AccDataType,
// GemmShape,
// MXGemmTraits,
// GemmConfig::Scheduler>;
// Use the new MX comp_async pipeline with MX scaling support
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
// // Use the new MX comp_async pipeline with MX scaling support
// using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
// Simplified invocation - comp_async handles hot loop and tail internally
auto invoke_splitk_path = [&](auto split_k_) {

View File

@@ -43,7 +43,7 @@ struct MXGemmHostArgs : ck_tile::UniversalGemmHostArgs<1, 1, 0>
struct MxGemmConfig
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 512;
static constexpr ck_tile::index_t M_Warp = 1;
@@ -74,7 +74,7 @@ struct MxGemmConfig
// GEMM configuration that derives from MxGemmConfig and redeclares the
// M/N/K tile-size constants.
// NOTE(review): this text appears to be a rendered diff with the +/- markers
// stripped; the two N_Tile lines below are presumably the removed (256) and
// added (128) versions of a single line — confirm against the actual file.
struct MXfp4_GemmConfig16 : MxGemmConfig
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;
};

View File

@@ -154,17 +154,17 @@ int run_mx_gemm_example(int argc, char* argv[])
MXfp4_GemmConfig16,
true>(argc, argv, Row{}, Col{}, Row{});
}
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
{
return run_mx_gemm_with_layouts<ck_tile::fp8_t,
ck_tile::fp8_t,
float,
MXfp8_GemmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
}
// else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
// {
// return run_mx_gemm_with_layouts<ck_tile::fp8_t,
// ck_tile::fp8_t,
// float,
// MXfp8_GemmConfig16,
// false>(argc, argv, Row{}, Col{}, Row{});
// }
else
{
throw std::runtime_error("Only fp4 and fp8 is supported currently!");
throw std::runtime_error("Only fp4 is supported currently!");
}
}
else