Add permuteN optimization when NRepeat % 2 == 0 on flatmm

This commit is contained in:
Feng Shijie
2025-07-27 11:57:38 +00:00
parent bfb9f4002f
commit 5473f06461
5 changed files with 228 additions and 104 deletions

View File

@@ -66,18 +66,19 @@ template <typename FlatmmConfig, typename T>
auto shuffle_b_v1(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
constexpr int NShuffleStride = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile;
FlatmmConfig::N_ ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Tile,
FlatmmConfig::N_Warp_Tile,
NShuffleStride,
k_ / FlatmmConfig::K_Warp_Tile,
divisor,
FlatmmConfig::K_Warp_Tile / divisor});
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile / FlatmmConfig::N_Warp;
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Tile,
FlatmmConfig::N_Warp,
FlatmmConfig::N_Warp_Tile,
NRepeat,
k_ / FlatmmConfig::K_Warp_Tile,
divisor,
FlatmmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 4, 1, 5});
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
@@ -202,7 +203,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups>>;
FlatmmConfig::NumWaveGroups,
false,
1,
FlatmmConfig::TiledMMAPermuteN>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
@@ -231,39 +235,39 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
if(s.flush_cache_)
{
// std::cout << "Flushing cache..." << std::endl;
// static constexpr ck_tile::index_t APackedSize =
// std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
// static constexpr ck_tile::index_t BPackedSize =
// std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
// ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
// args.M, args.K, args.stride_A, is_row_major(ALayout{})));
// ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
// args.K, args.N, args.stride_B, is_row_major(BLayout{})));
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
// auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
// auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
// ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
// kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
// rotating_mem.Print();
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
// auto run_flush_cache = [&]() {
// // flush icache
// ck_tile::flush_icache();
// // rotating mem
// rotating_mem.Next();
// // clear c mem
// if(args.k_batch > 1)
// hipGetErrorString(hipMemsetAsync(
// args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
// };
// ave_time = ck_tile::launch_kernel_preprocess(
// s,
// run_flush_cache,
// ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
// Kernel{}, grids, blocks, 0, kargs));
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
@@ -285,10 +289,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
}
else
{
// Run(has_hot_loop_,
// tail_number_,
// ck_tile::integral_constant<ck_tile::memory_operation_enum,
// ck_tile::memory_operation_enum::atomic_add>{});
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
@@ -416,14 +420,14 @@ int run_flatmm_example(int argc, char* argv[])
{
if(data_type == "fp16")
{
// run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
// argc, argv, Row{}, Col{}, Row{});
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
argc, argv, Row{}, Col{}, Row{});
}
// else if(data_type == "bf16")
// {
// run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
// argc, argv, Row{}, Col{}, Row{});
// }
else if(data_type == "fp8")
{
if(scale_opt == 0)
@@ -433,25 +437,23 @@ int run_flatmm_example(int argc, char* argv[])
}
else
{
// run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>, 1,
// 1>(
// argc, argv, Row{}, Col{}, Row{});
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>, 1, 1>(
argc, argv, Row{}, Col{}, Row{});
}
}
else if(data_type == "bf8")
{
if(scale_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1, 1>(
argc, argv, Row{}, Col{}, Row{});
}
}
// else if(data_type == "bf8")
// {
// if(scale_opt == 0)
// {
// run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
// argc, argv, Row{}, Col{}, Row{});
// }
// else
// {
// run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1,
// 1>(
// argc, argv, Row{}, Col{}, Row{});
// }
// }
else
{
throw std::runtime_error("Unsupported data_type!");
@@ -477,18 +479,18 @@ int main(int argc, char* argv[])
{
return !run_flatmm_example<FlatmmConfig16>(argc, argv);
}
// else if(warp_tile == 1)
// {
// return !run_flatmm_example<FlatmmConfig32>(argc, argv);
// }
// else if(warp_tile == 2)
// {
// return !run_flatmm_example<FlatmmConfig16_950>(argc, argv);
// }
// else
// {
// return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
// }
else if(warp_tile == 1)
{
return !run_flatmm_example<FlatmmConfig32>(argc, argv);
}
else if(warp_tile == 2)
{
return !run_flatmm_example<FlatmmConfig16_950>(argc, argv);
}
else
{
return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
}
}
catch(const std::runtime_error& e)
{