mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
fixed vector load size for fp4
This commit is contained in:
@@ -60,40 +60,40 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
|
||||
scale_m,
|
||||
scale_n);
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile>>;
|
||||
// using GemmShape = ck_tile::TileGemmShape<
|
||||
// ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
// ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
// ck_tile::sequence<GemmConfig::M_Warp_Tile,
|
||||
// GemmConfig::N_Warp_Tile,
|
||||
// GemmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
// using TilePartitioner =
|
||||
// ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
// GemmConfig::TileParitionerGroupNum,
|
||||
// GemmConfig::TileParitionerM01>;
|
||||
|
||||
using MXGemmTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
UsePersistentKernel,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false>;
|
||||
// using MXGemmTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
// GemmConfig::kPadN,
|
||||
// GemmConfig::kPadK,
|
||||
// GemmConfig::DoubleSmemBuffer,
|
||||
// ALayout,
|
||||
// BLayout,
|
||||
// CLayout,
|
||||
// GemmConfig::TransposeC,
|
||||
// GemmConfig::UseStructuredSparsity,
|
||||
// UsePersistentKernel,
|
||||
// GemmConfig::NumWaveGroups,
|
||||
// false>;
|
||||
|
||||
using MXPipelineProblem = MXGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
MXGemmTraits,
|
||||
GemmConfig::Scheduler>;
|
||||
// using MXPipelineProblem = MXGemmPipelineProblem<ADataType,
|
||||
// BDataType,
|
||||
// AccDataType,
|
||||
// GemmShape,
|
||||
// MXGemmTraits,
|
||||
// GemmConfig::Scheduler>;
|
||||
|
||||
// Use the new MX comp_async pipeline with MX scaling support
|
||||
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
|
||||
// // Use the new MX comp_async pipeline with MX scaling support
|
||||
// using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
|
||||
|
||||
// Simplified invocation - comp_async handles hot loop and tail internally
|
||||
auto invoke_splitk_path = [&](auto split_k_) {
|
||||
|
||||
Reference in New Issue
Block a user