try to remove c_shuffle_lds

This commit is contained in:
sjfeng
2025-07-27 17:24:08 +08:00
parent 1264f4d2ab
commit bfb9f4002f
3 changed files with 167 additions and 71 deletions

View File

@@ -13,7 +13,6 @@
#include "flatmm_basic.hpp"
#include <type_traits>
template <typename T>
constexpr const char* DataTypeToString()
{
@@ -63,6 +62,24 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
/// Re-arranges a rank-2 host tensor of B values into a 6-D blocked order.
///
/// The elements of `t` are copied, in iteration order, into a temporary tensor
/// of shape
///   { N / N_Tile, N_Warp_Tile, N_Tile / N_Warp_Tile,
///     K / K_Warp_Tile, divisor, K_Warp_Tile / divisor }
/// where K = t.get_lengths()[0], N = t.get_lengths()[1], and `divisor` is 2
/// when FlatmmConfig::N_Warp_Tile == 32 and 4 otherwise. That temporary is then
/// permuted with axis order {0, 2, 3, 4, 1, 5}.
///
/// NOTE(review): the reshape treats N as the slowest-varying dimension of the
/// copied data, which presumably matches a column-major B — confirm against
/// the caller.
///
/// @tparam FlatmmConfig  Provides N_Tile, N_Warp_Tile and K_Warp_Tile.
/// @tparam T             Element type of the tensor.
/// @param  t             Rank-2 tensor with lengths {K, N}.
/// @return The result of ck_tile::reference_permute on the reshaped copy.
template <typename FlatmmConfig, typename T>
auto shuffle_b_v1(const ck_tile::HostTensor<T>& t)
{
    assert(t.get_lengths().size() == 2);
    const int n_ = t.get_lengths()[1];
    const int k_ = t.get_lengths()[0];

    constexpr int divisor        = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
    constexpr int NShuffleStride = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile;

    // The reshape below uses truncating integer division. If N or K is not a
    // whole number of tiles, t_view would hold fewer elements than t and the
    // std::copy would write past its end.
    assert(n_ % FlatmmConfig::N_Tile == 0);
    assert(k_ % FlatmmConfig::K_Warp_Tile == 0);

    ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Tile,
                                   FlatmmConfig::N_Warp_Tile,
                                   NShuffleStride,
                                   k_ / FlatmmConfig::K_Warp_Tile,
                                   divisor,
                                   FlatmmConfig::K_Warp_Tile / divisor});
    std::copy(t.begin(), t.end(), t_view.begin());

    // Move the N_Warp_Tile axis (index 1) to just before the innermost K chunk.
    return ck_tile::reference_permute(t_view, {0, 2, 3, 4, 1, 5});
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
@@ -84,7 +101,6 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
@@ -215,39 +231,39 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
// std::cout << "Flushing cache..." << std::endl;
// static constexpr ck_tile::index_t APackedSize =
// std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
// static constexpr ck_tile::index_t BPackedSize =
// std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
// ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
// args.M, args.K, args.stride_A, is_row_major(ALayout{})));
// ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
// args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
// auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
// auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
// ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
// kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
// rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
// auto run_flush_cache = [&]() {
// // flush icache
// ck_tile::flush_icache();
// // rotating mem
// rotating_mem.Next();
// // clear c mem
// if(args.k_batch > 1)
// hipGetErrorString(hipMemsetAsync(
// args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
// };
// ave_time = ck_tile::launch_kernel_preprocess(
// s,
// run_flush_cache,
// ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
// Kernel{}, grids, blocks, 0, kargs));
}
else
{
@@ -269,10 +285,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
// Run(has_hot_loop_,
// tail_number_,
// ck_tile::integral_constant<ck_tile::memory_operation_enum,
// ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
@@ -400,14 +416,14 @@ int run_flatmm_example(int argc, char* argv[])
{
if(data_type == "fp16")
{
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
argc, argv, Row{}, Col{}, Row{});
// run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
// argc, argv, Row{}, Col{}, Row{});
}
// else if(data_type == "bf16")
// {
// run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
// argc, argv, Row{}, Col{}, Row{});
// }
else if(data_type == "fp8")
{
if(scale_opt == 0)
@@ -417,23 +433,25 @@ int run_flatmm_example(int argc, char* argv[])
}
else
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>, 1, 1>(
argc, argv, Row{}, Col{}, Row{});
}
}
else if(data_type == "bf8")
{
if(scale_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1, 1>(
argc, argv, Row{}, Col{}, Row{});
// run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>, 1,
// 1>(
// argc, argv, Row{}, Col{}, Row{});
}
}
// else if(data_type == "bf8")
// {
// if(scale_opt == 0)
// {
// run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
// argc, argv, Row{}, Col{}, Row{});
// }
// else
// {
// run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1,
// 1>(
// argc, argv, Row{}, Col{}, Row{});
// }
// }
else
{
throw std::runtime_error("Unsupported data_type!");
@@ -459,18 +477,18 @@ int main(int argc, char* argv[])
{
return !run_flatmm_example<FlatmmConfig16>(argc, argv);
}
else if(warp_tile == 1)
{
return !run_flatmm_example<FlatmmConfig32>(argc, argv);
}
else if(warp_tile == 2)
{
return !run_flatmm_example<FlatmmConfig16_950>(argc, argv);
}
else
{
return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
}
// else if(warp_tile == 1)
// {
// return !run_flatmm_example<FlatmmConfig32>(argc, argv);
// }
// else if(warp_tile == 2)
// {
// return !run_flatmm_example<FlatmmConfig16_950>(argc, argv);
// }
// else
// {
// return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
// }
}
catch(const std::runtime_error& e)
{