refine example of rmsnorm

This commit is contained in:
rocking
2024-10-24 08:56:54 +00:00
parent a50ec83d03
commit e4a169dd47

View File

@@ -4,23 +4,6 @@
#include "ck_tile/ops/rmsnorm2d.hpp"
#include <cstring>
// Validation error thresholds (rtol, atol); may be specialized per data type.
/// Default error limits for every DataType that has no explicit specialization.
///
/// @tparam DataType element type the limits are selected for (unused by the
///         primary template itself).
/// @return a ck_tile tuple of two doubles, in the order (rtol, atol); both
///         are 1e-2 here.
template <typename DataType>
auto get_elimit()
{
    // First element is rtol, second is atol -- callers unpack them in that order.
    return ck_tile::make_tuple(/*rtol=*/1e-2, /*atol=*/1e-2);
}
/// Error limits for ck_tile::bf16_t.
///
/// @return a ck_tile tuple of two doubles, in the order (rtol, atol).
/// @note The values are currently the same as those of the primary template
///       (1e-2 for both); the specialization exists as the hook for giving
///       bf16 its own limits.
template <>
auto get_elimit<ck_tile::bf16_t>()
{
    constexpr double bf16_rtol = 1e-2;
    constexpr double bf16_atol = 1e-2;
    return ck_tile::make_tuple(bf16_rtol, bf16_atol);
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
@@ -29,10 +12,9 @@ auto create_args(int argc, char* argv[])
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("e", "1e-5", "epsilon")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
.insert("warmup", "0", "cold iter")
.insert("repeat", "1", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -48,7 +30,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride = n;
float epsilon = arg_parser.get_float("e");
std::string data_type = arg_parser.get_str("prec");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
@@ -81,15 +62,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
x_buf.ToDevice(x_host.data());
gamma_buf.ToDevice(gamma_host.data());
std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
constexpr bool kTwoPass = true;
using BlockWarps = ck_tile::sequence<4, 1>;
using BlockTile = ck_tile::sequence<8, 64>;
using WarpTile = ck_tile::sequence<2, 64>;
using Vector = ck_tile::sequence<1, 2>;
using BlockWarps = ck_tile::sequence<2, 2>;
using BlockTile = ck_tile::sequence<2, 256>;
using WarpTile = ck_tile::sequence<1, 64>;
using Vector = ck_tile::sequence<1, 1>;
using Shape = ck_tile::Rmsnorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem<XDataType,
@@ -98,7 +76,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
YDataType,
InvRmsDataType,
Shape,
false, // kPadN
true, // kPadN
false, // kSaveInvRms
kTwoPass>;
@@ -121,7 +99,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const dim3 grids = Kernel::GridSize(args);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto s = ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat};
auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat};
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
@@ -140,7 +118,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_buf.FromDevice(y_host_dev.data());
auto [rtol, atol] = get_elimit<DataType>();
auto [rtol, atol] = ck_tile::make_tuple(1e-3, 1e-3);
if(stride == n)
{
pass = ck_tile::check_err(
@@ -163,7 +141,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;