mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Add permuteN optimzization when NRepeat % 2 == 0 on flatmm
This commit is contained in:
@@ -66,18 +66,19 @@ template <typename FlatmmConfig, typename T>
|
||||
auto shuffle_b_v1(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
constexpr int NShuffleStride = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile;
|
||||
FlatmmConfig::N_ ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
NShuffleStride,
|
||||
k_ / FlatmmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
FlatmmConfig::K_Warp_Tile / divisor});
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile / FlatmmConfig::N_Warp;
|
||||
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Tile,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / FlatmmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
FlatmmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 4, 1, 5});
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
@@ -202,7 +203,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups>>;
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
FlatmmConfig::TiledMMAPermuteN>>;
|
||||
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
@@ -231,39 +235,39 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
// std::cout << "Flushing cache..." << std::endl;
|
||||
// static constexpr ck_tile::index_t APackedSize =
|
||||
// std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
// static constexpr ck_tile::index_t BPackedSize =
|
||||
// std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
// ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
// args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
// ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
// args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
// auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
// auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
// ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
// kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
// rotating_mem.Print();
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
// auto run_flush_cache = [&]() {
|
||||
// // flush icache
|
||||
// ck_tile::flush_icache();
|
||||
// // rotating mem
|
||||
// rotating_mem.Next();
|
||||
// // clear c mem
|
||||
// if(args.k_batch > 1)
|
||||
// hipGetErrorString(hipMemsetAsync(
|
||||
// args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
// };
|
||||
// ave_time = ck_tile::launch_kernel_preprocess(
|
||||
// s,
|
||||
// run_flush_cache,
|
||||
// ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
// Kernel{}, grids, blocks, 0, kargs));
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -285,10 +289,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
}
|
||||
else
|
||||
{
|
||||
// Run(has_hot_loop_,
|
||||
// tail_number_,
|
||||
// ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// ck_tile::memory_operation_enum::atomic_add>{});
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
@@ -416,14 +420,14 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
// run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
// else if(data_type == "bf16")
|
||||
// {
|
||||
// run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
if(scale_opt == 0)
|
||||
@@ -433,25 +437,23 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
}
|
||||
else
|
||||
{
|
||||
// run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>, 1,
|
||||
// 1>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>, 1, 1>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
if(scale_opt == 0)
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1, 1>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
// else if(data_type == "bf8")
|
||||
// {
|
||||
// if(scale_opt == 0)
|
||||
// {
|
||||
// run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1,
|
||||
// 1>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// }
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
@@ -477,18 +479,18 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
// else if(warp_tile == 1)
|
||||
// {
|
||||
// return !run_flatmm_example<FlatmmConfig32>(argc, argv);
|
||||
// }
|
||||
// else if(warp_tile == 2)
|
||||
// {
|
||||
// return !run_flatmm_example<FlatmmConfig16_950>(argc, argv);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
|
||||
// }
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig32>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 2)
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig16_950>(argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
|
||||
}
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user