v1 v2 works with flatmm develop

This commit is contained in:
AviralGoelAMD
2025-07-16 19:27:36 +00:00
parent c1c30b1c18
commit 9b3e700c7d
4 changed files with 39 additions and 4 deletions

View File

@@ -33,6 +33,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile>>;
std::cout << "CodegenFlatmmShape: " << CodegenFlatmmShape::GetName() << std::endl;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
@@ -219,8 +220,10 @@ int run_flatmm_example(int argc, char* argv[])
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "C")
{
if(data_type == "fp16")
{
{
std::cout << "Running with fp16 data type" << std::endl;
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
argc, argv, Row{}, Col{}, Row{});
}
@@ -261,11 +264,13 @@ int main(int argc, char* argv[])
{
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
{
std::cout << "Running with warp tile size 16x16" << std::endl;
return !run_flatmm_example<FlatmmConfig16>(argc, argv);
}
else if(warp_tile == 1)
{
std::cout << "Running with 32x32 tile size" << std::endl;
return !run_flatmm_example<FlatmmConfig32>(argc, argv);
}
else if(warp_tile == 2)