Files
Aviral Goel d85f065b15 chore(copyright): update copyright header for example directory (#3273)
* chore(copyright): update copyright header for codegen directory

* chore(copyright): update copyright header for example directory
2025-11-24 18:02:41 -08:00

422 lines
13 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "permute.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <array>
#include <cstring>
#include <functional>
#include <numeric>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
#include "alternative_impl/matrix_core_swizzle.hpp"
#endif
namespace detail {

// Maps a width in bytes to the fixed-width signed integer type of that size.
// The primary template is only declared, so any width other than 1, 2 or 4
// is rejected at compile time as an incomplete type.
template <int NumBytes>
struct to_integer_type;

template <>
struct to_integer_type<1>
{
    using type = int8_t;
};

template <>
struct to_integer_type<2>
{
    using type = int16_t;
};

template <>
struct to_integer_type<4>
{
    using type = int32_t;
};

} // namespace detail

// Convenience alias: to_integer_type<N> is the signed integer that is N bytes wide.
template <int NumBytes>
using to_integer_type = typename detail::to_integer_type<NumBytes>::type;
// host API (shoule come from codegen)
float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp8") == 0)
{
using DataType = ck_tile::fp8_t;
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
const dim3 blocks = Kernel::BlockSize();
float ave_time =
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
else if(t.data_type.compare("fp16") == 0)
{
using DataType = ck_tile::half_t;
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
const dim3 blocks = Kernel::BlockSize();
float ave_time =
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
else if(t.data_type.compare("fp32") == 0)
{
using DataType = float;
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
const dim3 blocks = Kernel::BlockSize();
float ave_time =
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
return 0;
}
// Streams a vector as "[e0, e1, ...]"; an empty vector is written as "[]".
// Elements are written with their own operator<<, separated by ", ".
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";
    auto it        = v.begin();
    const auto end = v.end();
    if(it != end)
    {
        // First element has no leading separator.
        os << *it;
        for(++it; it != end; ++it)
        {
            os << ", ";
            os << *it;
        }
    }
    os << "]";
    return os;
}
// Registers the command-line options of this example (each with its default
// value and help text) on a ck_tile::ArgParser and parses argc/argv with it.
// Returns std::make_tuple(parse_result, arg_parser): the bool returned by
// ArgParser::parse and a copy of the populated parser.
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("v", "1", "whether to do CPU validation or not")
        .insert("prec", "fp16", "data type. fp8/fp16/fp32 (representing 8/16/32 bit data)")
        .insert("shape", "2,3,4", "the shape of the input tensor")
        .insert("perm", "2,1,0", "permute perm")
        .insert("kname", "0", "set to 1 will print kernel name")
        .insert("seed",
                "11939",
                "random seed used for initializing input tensors. 0 for "
                "non-deterministic seed")
        .insert("warmup", "5", "number of iterations before benchmark the kernel")
        .insert("repeat", "20", "number of iterations to benchmark the kernel")
        .insert("json", "0", "0: No Json, 1: Dump Results in Json format")
        .insert("jsonfile", "permute.json", "json file name to dump results");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// Error limits differ per data type; this is the generic case.
// Returns ck_tile::make_tuple(rtol, atol) with both set to 1e-3.
// The init-method argument is ignored here.
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
{
    const double relative_tolerance = 1e-3;
    const double absolute_tolerance = 1e-3;
    return ck_tile::make_tuple(relative_tolerance, absolute_tolerance);
}
// bf16 specialization: returns ck_tile::make_tuple(rtol, atol) with both set
// to 1e-2. The init-method argument is ignored.
template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{
    const double relative_tolerance = 1e-2;
    const double absolute_tolerance = 1e-2;
    return ck_tile::make_tuple(relative_tolerance, absolute_tolerance);
}
// fp8 specialization: returns ck_tile::make_tuple(max_rounding_point_distance, atol).
// For init_method "ui" or "ni" the pair is (0, 2e-3); for any other init
// method it is (1, 0.0625).
template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{
    const bool is_ui_or_ni = (init_method == "ui") || (init_method == "ni");

    const unsigned max_rounding_point_distance = is_ui_or_ni ? 0u : 1u;
    const double atol                          = is_ui_or_ni ? 2e-3 : 0.0625;

    return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
// "1,2,3,4" -> vector{1,2,3,4}
std::vector<ck_tile::index_t> decode_vec(std::string q_val)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
std::string::size_type pos = 0;
std::vector<ck_tile::index_t> v;
while(true)
{
auto found = q_val.find(',', pos);
ck_tile::index_t n =
_S2I_(q_val.substr(pos, found == std::string::npos ? found : found - pos));
v.push_back(n);
if(found == std::string::npos)
{
break;
}
pos = found + 1;
}
return v;
#undef _S2I_
}
// Runs the permute example once for element type DataType.
//
// Reads prec / v / shape / perm / warmup / repeat / kname / seed (plus json
// and jsonfile) from arg_parser, builds a host tensor of the requested shape
// filled via FillUniformDistributionIntegerValue{-15, 15, seed}, copies it to
// the device, launches the permute kernel and prints a one-line summary to
// std::cout. When built with PERMUTE_USE_ALTERNATIVE_IMPL and without
// CK_TILE_USE_WMMA, a few specific perm/shape combinations are routed to
// matrix_core_swizzle instead of the generic permute.
//
// Returns false when the arguments are rejected (shape and perm have different
// lengths, the rank exceeds GenericPermuteHostArgs::kMaxRanks, or perm is not a
// permutation of the axes) or when validation is enabled and the device result
// is not bit-identical to reference_permute. Returns true otherwise.
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
    std::string data_type = arg_parser.get_str("prec");
    int do_validation     = arg_parser.get_int("v");
    auto shape            = decode_vec(arg_parser.get_str("shape"));
    auto perm             = decode_vec(arg_parser.get_str("perm"));
    int stream_warmup     = arg_parser.get_int("warmup");
    int stream_repeat     = arg_parser.get_int("repeat");
    bool kname            = arg_parser.get_bool("kname");
    int seed              = arg_parser.get_int("seed");

    // shape and perm come straight from the command line, so check them at
    // runtime: an assert would vanish under NDEBUG and let the indexing below
    // (shape[perm[i]], the copies into a.shape / a.perm, shape[5], shape[6])
    // run out of bounds.
    if(shape.size() != perm.size())
    {
        printf("shape has %zu dims but perm has %zu entries\n", shape.size(), perm.size());
        return false;
    }

    ck_tile::index_t rank = perm.size();
    if(rank > ck_tile::GenericPermuteHostArgs::kMaxRanks)
    {
        printf("rank %d permute is not support yet\n", rank);
        return false;
    }

    // Every axis index must be in [0, rank) and appear exactly once, because
    // perm entries are used as indices into shape.
    {
        std::vector<char> seen(perm.size(), 0);
        for(const ck_tile::index_t axis : perm)
        {
            if(axis < 0 || axis >= rank || seen[static_cast<std::size_t>(axis)] != 0)
            {
                printf("perm must contain each axis in [0, %d) exactly once\n", rank);
                return false;
            }
            seen[static_cast<std::size_t>(axis)] = 1;
        }
    }

    ck_tile::HostTensor<DataType> x(shape);
    ck_tile::FillUniformDistributionIntegerValue<DataType>{-15, 15, seed}(x);

    // Output dimension i takes the length of input dimension perm[i].
    std::vector<ck_tile::index_t> y_shape(rank, 0);
    for(int i = 0; i < static_cast<int>(rank); i++)
    {
        y_shape[i] = shape[perm[i]];
    }

    ck_tile::HostTensor<DataType> y(y_shape);

    ck_tile::DeviceMem x_buf(x.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_buf(y.get_element_space_size_in_bytes());

    x_buf.ToDevice(x.data());

    std::cout << "[" << data_type << "] shape:" << shape << "->" << y_shape << ", permute:" << perm
              << std::flush;

    ck_tile::stream_config stream_config{nullptr,
                                         true,
                                         /* log_level = */ (kname ? 1 : 0),
                                         stream_warmup,
                                         stream_repeat};
    float ave_time = 0.f;

    // Generic path: fill permute_args from the parsed shape/perm and launch permute().
    auto run_permute = [&]() {
        permute_traits t;
        t.data_type = data_type;

        permute_args a;
        a.p_src = x_buf.GetDeviceBuffer();
        a.p_dst = y_buf.GetDeviceBuffer();
        a.rank  = rank;
        std::copy(shape.begin(), shape.end(), a.shape);
        std::copy(perm.begin(), perm.end(), a.perm);

        return permute(t, a, stream_config);
    };
#if !CK_TILE_USE_WMMA
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
    // batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
    if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") ||
        arg_parser.get_str("perm") == std::string("0,1,2,4,5,3,6") ||
        arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")))
    {
        if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
        {
            // b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
            matrix_core_swizzle_traits t;
            t.data_type = data_type;
            t.permute   = arg_parser.get_str("perm");

            matrix_core_swizzle_args a;
            a.p_src = x_buf.GetDeviceBuffer();
            a.p_dst = y_buf.GetDeviceBuffer();
            a.batch = shape[0];

            auto nr = shape[1];
            auto nw = shape[2];
            auto kr = shape[3];
            auto kw = shape[4];
            auto kv = shape[5];
            a.n     = nr * nw;
            a.k     = kr * kw * kv;
            if(kv == 8 && kw == 4 && nw == 16 && nr % 4 == 0 && kr % 8 == 0)
            {
                t.inst = "16x16x16";
                std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;

                ave_time = matrix_core_swizzle(t, a, stream_config);
            }
            else if(kv == 8 && kw == 2 && nw == 32 && nr % 4 == 0 && kr % 8 == 0)
            {
                t.inst = "32x32x8";
                std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;

                ave_time = matrix_core_swizzle(t, a, stream_config);
            }
            else
            {
                // Shape does not match either swizzle layout: use the generic kernel.
                ave_time = run_permute();
            }
        }
        else
        {
            matrix_core_swizzle_traits t;
            t.data_type = data_type;
            t.permute   = arg_parser.get_str("perm");

            matrix_core_swizzle_args a;
            a.p_src = x_buf.GetDeviceBuffer();
            a.p_dst = y_buf.GetDeviceBuffer();
            a.batch = shape[0];
            a.n     = shape[1] * shape[2] * shape[3];
            a.k     = shape[4] * shape[5] * shape[6];
            if(shape[6] == 8 && shape[3] == 32 && shape[5] == 2 && shape[2] == 4 &&
               shape[4] % 8 == 0 && shape[1] % 2 == 0)
            {
                // 32x32x8 inst
                // perm=0,1,4,2,5,3,6
                // y_shape=*,2x,8x,4,2,32,8 (3,6,16,4,2,32,8)
                // shape = *,2x,4,32,8x,2,8 (3,6,4,32,16,2,8)

                t.inst = "32x32x8";
                std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;

                ave_time = matrix_core_swizzle(t, a, stream_config);
            }
            else if(shape[6] == 8 && shape[3] == 16 && shape[5] == 4 && shape[2] == 4 &&
                    shape[4] % 4 == 0 && shape[1] % 4 == 0)
            {
                // 16x16x16 inst
                // perm=0,1,4,2,5,3,6
                // y_shape=*,4x,4x,4,4,16,8
                // shape = *,4x,4,16,4x,4,8 (3,8,4,16,16,4,8)
                t.inst = "16x16x16";
                std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;

                ave_time = matrix_core_swizzle(t, a, stream_config);
            }
            else
            {
                // Shape does not match either swizzle layout: use the generic kernel.
                ave_time = run_permute();
            }
        }
    }
    else
#endif
#endif
    {
        ave_time = run_permute();
    }
    std::cout << ", time:" << ave_time << "ms" << std::flush;

    bool pass = true;
    if(do_validation)
    {
        reference_permute(x, y, perm);
#if 0
        if constexpr (std::is_same_v<float, DataType>){
            // using itype = to_integer_type<sizeof(DataType)>;
            fflush(stdout);
            for(int zz = 0; zz < static_cast<int>(x.get_element_size()); zz++ ) {
                printf("%3.0f ", x.mData[zz]);
            }
            printf("->\n");
            for(int zz = 0; zz < static_cast<int>(x.get_element_size()); zz++ ) {
                printf("%3.0f ", y.mData[zz]);
            }
            fflush(stdout);
        }
#endif
        ck_tile::HostTensor<DataType> y_dev(y.get_lengths());

        y_buf.FromDevice(y_dev.data());

        // Compare raw bit patterns rather than values: each element is
        // reinterpreted as a same-width integer before the equality test.
        pass = std::equal(
            y_dev.begin(), y_dev.end(), y.begin(), [&](const DataType& d, const DataType& h) {
                using itype = to_integer_type<sizeof(DataType)>;
                itype i_d   = ck_tile::bit_cast<itype>(d);
                itype i_h   = ck_tile::bit_cast<itype>(h);
                return i_d == i_h;
            });
        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
    }

    if(arg_parser.get_int("json") == 1)
    {
        dump_permute_json_results(arg_parser.get_str("jsonfile"), data_type, pass, ave_time, 0, 0);
    }

    std::cout << std::endl;

    return pass;
}
// Entry point: parses the command line and runs the example for the element
// type selected by "prec".
// Exit codes: 0 when run() returns true, -1 when argument parsing fails,
// -2 when run() returns false, -3 when "prec" is not fp8/fp16/fp32.
int main(int argc, char* argv[])
{
    auto [parsed_ok, parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        return -1;
    }

    const std::string precision = parser.get_str("prec");

    bool passed = false;
    if(precision == "fp8")
    {
        passed = run<ck_tile::fp8_t>(parser);
    }
    else if(precision == "fp16")
    {
        passed = run<ck_tile::half_t>(parser);
    }
    else if(precision == "fp32")
    {
        passed = run<float>(parser);
    }
    else
    {
        return -3;
    }

    return passed ? 0 : -2;
}