mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
try to remove c_shuffle_lds
This commit is contained in:
@@ -13,7 +13,6 @@
|
||||
#include "flatmm_basic.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
|
||||
template <typename T>
|
||||
constexpr const char* DataTypeToString()
|
||||
{
|
||||
@@ -63,6 +62,24 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig, typename T>
|
||||
auto shuffle_b_v1(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
constexpr int NShuffleStride = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile;
|
||||
FlatmmConfig::N_ ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
NShuffleStride,
|
||||
k_ / FlatmmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
FlatmmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 4, 1, 5});
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
@@ -84,7 +101,6 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -215,39 +231,39 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
// std::cout << "Flushing cache..." << std::endl;
|
||||
// static constexpr ck_tile::index_t APackedSize =
|
||||
// std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
// static constexpr ck_tile::index_t BPackedSize =
|
||||
// std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
// ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
// args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
// ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
// args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
// auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
// auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
// ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
// kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
// rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
// auto run_flush_cache = [&]() {
|
||||
// // flush icache
|
||||
// ck_tile::flush_icache();
|
||||
// // rotating mem
|
||||
// rotating_mem.Next();
|
||||
// // clear c mem
|
||||
// if(args.k_batch > 1)
|
||||
// hipGetErrorString(hipMemsetAsync(
|
||||
// args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
// };
|
||||
// ave_time = ck_tile::launch_kernel_preprocess(
|
||||
// s,
|
||||
// run_flush_cache,
|
||||
// ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
// Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -269,10 +285,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
// Run(has_hot_loop_,
|
||||
// tail_number_,
|
||||
// ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
@@ -400,14 +416,14 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
// run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
// else if(data_type == "bf16")
|
||||
// {
|
||||
// run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
if(scale_opt == 0)
|
||||
@@ -417,23 +433,25 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
}
|
||||
else
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>, 1, 1>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
if(scale_opt == 0)
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1, 1>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
// run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>, 1,
|
||||
// 1>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
// else if(data_type == "bf8")
|
||||
// {
|
||||
// if(scale_opt == 0)
|
||||
// {
|
||||
// run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1,
|
||||
// 1>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// }
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
@@ -459,18 +477,18 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig32>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 2)
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig16_950>(argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
|
||||
}
|
||||
// else if(warp_tile == 1)
|
||||
// {
|
||||
// return !run_flatmm_example<FlatmmConfig32>(argc, argv);
|
||||
// }
|
||||
// else if(warp_tile == 2)
|
||||
// {
|
||||
// return !run_flatmm_example<FlatmmConfig16_950>(argc, argv);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
|
||||
// }
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
|
||||
@@ -93,7 +93,7 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
|
||||
|
||||
// do pre-shuffle
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig>(b_origin_host);
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b_v1<FlatmmConfig>(b_origin_host);
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
|
||||
@@ -133,7 +133,7 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
if (ScaleGranularityM == -1 || ScaleGranularityN == -1)
|
||||
if(ScaleGranularityM == -1 || ScaleGranularityN == -1)
|
||||
throw std::runtime_error("ScaleAB is not supported for CPU verification!\n");
|
||||
ck_tile::HostTensor<CDataType> c_ref_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
@@ -257,6 +257,84 @@ struct CShuffleEpilogue
|
||||
}
|
||||
|
||||
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem)
|
||||
{
|
||||
constexpr int MRepeat = kMPerBlock / MPerIterationShuffle;
|
||||
constexpr int NRepeat = kNPerBlock / NPerIterationShuffle;
|
||||
|
||||
static_assert(MPerXdl == 16);
|
||||
|
||||
constexpr int kM0 = MWave;
|
||||
constexpr int kM2 = 4;
|
||||
constexpr int kM1 = MPerXdl / kM2;
|
||||
|
||||
constexpr int kN0 = NWave;
|
||||
constexpr int kN1 = NPerXdl;
|
||||
constexpr int kN2 = NRepeat;
|
||||
|
||||
using IntrThreadShuffleEncode =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>;
|
||||
constexpr auto dram_tile_distribution =
|
||||
make_static_tile_distribution(IntrThreadShuffleEncode{});
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
|
||||
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
c_out_tensor.get_thread_buffer()[n_idx + 0 * c_warp_y_lengths.product()] =
|
||||
type_convert<ODataType>(
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0]);
|
||||
c_out_tensor.get_thread_buffer()[n_idx + 1 * c_warp_y_lengths.product()] =
|
||||
type_convert<ODataType>(
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1]);
|
||||
c_out_tensor.get_thread_buffer()[n_idx + 2 * c_warp_y_lengths.product()] =
|
||||
type_convert<ODataType>(
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2]);
|
||||
c_out_tensor.get_thread_buffer()[n_idx + 3 * c_warp_y_lengths.product()] =
|
||||
type_convert<ODataType>(
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3]);
|
||||
});
|
||||
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <class, typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
|
||||
Reference in New Issue
Block a user