mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
fixed vector load siz for fp4
This commit is contained in:
@@ -60,40 +60,40 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
|
||||
scale_m,
|
||||
scale_n);
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile>>;
|
||||
// using GemmShape = ck_tile::TileGemmShape<
|
||||
// ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
// ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
// ck_tile::sequence<GemmConfig::M_Warp_Tile,
|
||||
// GemmConfig::N_Warp_Tile,
|
||||
// GemmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
// using TilePartitioner =
|
||||
// ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
// GemmConfig::TileParitionerGroupNum,
|
||||
// GemmConfig::TileParitionerM01>;
|
||||
|
||||
using MXGemmTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
UsePersistentKernel,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false>;
|
||||
// using MXGemmTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
// GemmConfig::kPadN,
|
||||
// GemmConfig::kPadK,
|
||||
// GemmConfig::DoubleSmemBuffer,
|
||||
// ALayout,
|
||||
// BLayout,
|
||||
// CLayout,
|
||||
// GemmConfig::TransposeC,
|
||||
// GemmConfig::UseStructuredSparsity,
|
||||
// UsePersistentKernel,
|
||||
// GemmConfig::NumWaveGroups,
|
||||
// false>;
|
||||
|
||||
using MXPipelineProblem = MXGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
MXGemmTraits,
|
||||
GemmConfig::Scheduler>;
|
||||
// using MXPipelineProblem = MXGemmPipelineProblem<ADataType,
|
||||
// BDataType,
|
||||
// AccDataType,
|
||||
// GemmShape,
|
||||
// MXGemmTraits,
|
||||
// GemmConfig::Scheduler>;
|
||||
|
||||
// Use the new MX comp_async pipeline with MX scaling support
|
||||
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
|
||||
// // Use the new MX comp_async pipeline with MX scaling support
|
||||
// using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
|
||||
|
||||
// Simplified invocation - comp_async handles hot loop and tail internally
|
||||
auto invoke_splitk_path = [&](auto split_k_) {
|
||||
|
||||
@@ -43,7 +43,7 @@ struct MXGemmHostArgs : ck_tile::UniversalGemmHostArgs<1, 1, 0>
|
||||
struct MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 512;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
@@ -74,7 +74,7 @@ struct MxGemmConfig
|
||||
struct MXfp4_GemmConfig16 : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
|
||||
|
||||
@@ -154,17 +154,17 @@ int run_mx_gemm_example(int argc, char* argv[])
|
||||
MXfp4_GemmConfig16,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
|
||||
{
|
||||
return run_mx_gemm_with_layouts<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
float,
|
||||
MXfp8_GemmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
// else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
|
||||
// {
|
||||
// return run_mx_gemm_with_layouts<ck_tile::fp8_t,
|
||||
// ck_tile::fp8_t,
|
||||
// float,
|
||||
// MXfp8_GemmConfig16,
|
||||
// false>(argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Only fp4 and fp8 is supported currently!");
|
||||
throw std::runtime_error("Only fp4 is supported currently!");
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user