From e4a169dd47592937ba4e262877ea26b8afee6b20 Mon Sep 17 00:00:00 2001 From: rocking Date: Thu, 24 Oct 2024 08:56:54 +0000 Subject: [PATCH] refine example of rmsnorm --- .../06_rmsnorm2d/example_rmsnorm2d_fwd.cpp | 43 +++++-------------- 1 file changed, 11 insertions(+), 32 deletions(-) diff --git a/example/ck_tile/06_rmsnorm2d/example_rmsnorm2d_fwd.cpp b/example/ck_tile/06_rmsnorm2d/example_rmsnorm2d_fwd.cpp index f8f426b63b..9dd4c8e001 100644 --- a/example/ck_tile/06_rmsnorm2d/example_rmsnorm2d_fwd.cpp +++ b/example/ck_tile/06_rmsnorm2d/example_rmsnorm2d_fwd.cpp @@ -4,23 +4,6 @@ #include "ck_tile/ops/rmsnorm2d.hpp" #include -// different threshold for different dtype -template -auto get_elimit() -{ - double rtol = 1e-2; - double atol = 1e-2; - return ck_tile::make_tuple(rtol, atol); -} - -template <> -auto get_elimit() -{ - double rtol = 1e-2; - double atol = 1e-2; - return ck_tile::make_tuple(rtol, atol); -} - auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -29,10 +12,9 @@ auto create_args(int argc, char* argv[]) .insert("stride", "-1", "stride per row, if -1 then equal to n") .insert("e", "1e-5", "epsilon") .insert("v", "1", "cpu validation or not") - .insert("kname", "1", "print kernel name or not") .insert("prec", "fp16", "precision") - .insert("warmup", "5", "cold iter") - .insert("repeat", "20", "hot iter"); + .insert("warmup", "0", "cold iter") + .insert("repeat", "1", "hot iter"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -48,7 +30,6 @@ bool run(const ck_tile::ArgParser& arg_parser) stride = n; float epsilon = arg_parser.get_float("e"); std::string data_type = arg_parser.get_str("prec"); - int kname = arg_parser.get_int("kname"); int do_validation = arg_parser.get_int("v"); int warmup = arg_parser.get_int("warmup"); int repeat = arg_parser.get_int("repeat"); @@ -81,15 +62,12 @@ bool run(const ck_tile::ArgParser& arg_parser) x_buf.ToDevice(x_host.data()); gamma_buf.ToDevice(gamma_host.data()); - std::cout << "[" << data_type << "]" - << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; - constexpr bool kTwoPass = true; - using BlockWarps = ck_tile::sequence<4, 1>; - using BlockTile = ck_tile::sequence<8, 64>; - using WarpTile = ck_tile::sequence<2, 64>; - using Vector = ck_tile::sequence<1, 2>; + using BlockWarps = ck_tile::sequence<2, 2>; + using BlockTile = ck_tile::sequence<2, 256>; + using WarpTile = ck_tile::sequence<1, 64>; + using Vector = ck_tile::sequence<1, 1>; using Shape = ck_tile::Rmsnorm2dShape; using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem; @@ -121,7 +99,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const dim3 grids = Kernel::GridSize(args); constexpr dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; - auto s = ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}; + auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat}; ck_tile::launch_kernel( s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); @@ -140,7 +118,7 @@ bool run(const ck_tile::ArgParser& arg_parser) y_buf.FromDevice(y_host_dev.data()); - auto [rtol, atol] = get_elimit(); + auto [rtol, atol] = ck_tile::make_tuple(1e-3, 1e-3); if(stride == n) { pass = ck_tile::check_err( @@ -163,7 +141,8 @@ bool run(const ck_tile::ArgParser& arg_parser) } } - std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride + << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } return pass;