mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
v1 v2 works with flatmm develop
This commit is contained in:
@@ -33,6 +33,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
std::cout << "CodegenFlatmmShape: " << CodegenFlatmmShape::GetName() << std::endl;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
|
||||
@@ -219,8 +220,10 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
{
|
||||
std::cout << "Running with fp16 data type" << std::endl;
|
||||
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
@@ -261,11 +264,13 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
{
|
||||
std::cout << "Running with warp tile size 16x16" << std::endl;
|
||||
return !run_flatmm_example<FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
std::cout << "Running with 32x32 tile size" << std::endl;
|
||||
return !run_flatmm_example<FlatmmConfig32>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 2)
|
||||
|
||||
Reference in New Issue
Block a user