mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
update example code
This commit is contained in:
@@ -60,41 +60,6 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
|
||||
scale_m,
|
||||
scale_n);
|
||||
|
||||
// using GemmShape = ck_tile::TileGemmShape<
|
||||
// ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
// ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
// ck_tile::sequence<GemmConfig::M_Warp_Tile,
|
||||
// GemmConfig::N_Warp_Tile,
|
||||
// GemmConfig::K_Warp_Tile>>;
|
||||
|
||||
// using TilePartitioner =
|
||||
// ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
// GemmConfig::TileParitionerGroupNum,
|
||||
// GemmConfig::TileParitionerM01>;
|
||||
|
||||
// using MXGemmTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
// GemmConfig::kPadN,
|
||||
// GemmConfig::kPadK,
|
||||
// GemmConfig::DoubleSmemBuffer,
|
||||
// ALayout,
|
||||
// BLayout,
|
||||
// CLayout,
|
||||
// GemmConfig::TransposeC,
|
||||
// GemmConfig::UseStructuredSparsity,
|
||||
// UsePersistentKernel,
|
||||
// GemmConfig::NumWaveGroups,
|
||||
// false>;
|
||||
|
||||
// using MXPipelineProblem = MXGemmPipelineProblem<ADataType,
|
||||
// BDataType,
|
||||
// AccDataType,
|
||||
// GemmShape,
|
||||
// MXGemmTraits,
|
||||
// GemmConfig::Scheduler>;
|
||||
|
||||
// // Use the new MX comp_async pipeline with MX scaling support
|
||||
// using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
|
||||
|
||||
// Simplified invocation - comp_async handles hot loop and tail internally
|
||||
auto invoke_splitk_path = [&](auto split_k_) {
|
||||
return mx_gemm_calc<GemmConfig,
|
||||
@@ -154,10 +119,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:constant(1)")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
.insert("init", "0", "0:random, 1:constant(1)");
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
@@ -66,16 +66,16 @@ struct MxGemmConfig
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = true; // comp_async uses double buffer
|
||||
static constexpr bool DoubleSmemBuffer = false; // comp_async uses double buffer
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
struct MXfp4_GemmConfig16 : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
static constexpr ck_tile::index_t M_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 512;
|
||||
};
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
|
||||
@@ -69,6 +69,12 @@ int run_mx_gemm_with_layouts(int argc,
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_a_host);
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_b_host);
|
||||
break;
|
||||
case 2:
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_host);
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_a_host);
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_b_host);
|
||||
break;
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
|
||||
Reference in New Issue
Block a user