mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
refine example of rmsnorm
This commit is contained in:
@@ -4,23 +4,6 @@
|
||||
#include "ck_tile/ops/rmsnorm2d.hpp"
|
||||
#include <cstring>
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
@@ -29,10 +12,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to n")
|
||||
.insert("e", "1e-5", "epsilon")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
.insert("warmup", "0", "cold iter")
|
||||
.insert("repeat", "1", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -48,7 +30,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
stride = n;
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
@@ -81,15 +62,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
x_buf.ToDevice(x_host.data());
|
||||
gamma_buf.ToDevice(gamma_host.data());
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
|
||||
|
||||
constexpr bool kTwoPass = true;
|
||||
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<8, 64>;
|
||||
using WarpTile = ck_tile::sequence<2, 64>;
|
||||
using Vector = ck_tile::sequence<1, 2>;
|
||||
using BlockWarps = ck_tile::sequence<2, 2>;
|
||||
using BlockTile = ck_tile::sequence<2, 256>;
|
||||
using WarpTile = ck_tile::sequence<1, 64>;
|
||||
using Vector = ck_tile::sequence<1, 1>;
|
||||
|
||||
using Shape = ck_tile::Rmsnorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem<XDataType,
|
||||
@@ -98,7 +76,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
YDataType,
|
||||
InvRmsDataType,
|
||||
Shape,
|
||||
false, // kPadN
|
||||
true, // kPadN
|
||||
false, // kSaveInvRms
|
||||
kTwoPass>;
|
||||
|
||||
@@ -121,7 +99,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
auto s = ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat};
|
||||
auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat};
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
@@ -140,7 +118,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
y_buf.FromDevice(y_host_dev.data());
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataType>();
|
||||
auto [rtol, atol] = ck_tile::make_tuple(1e-3, 1e-3);
|
||||
if(stride == n)
|
||||
{
|
||||
pass = ck_tile::check_err(
|
||||
@@ -163,7 +141,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride
|
||||
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
|
||||
Reference in New Issue
Block a user