diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp index 8f029c212c..b49c04619d 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -25,7 +25,10 @@ auto create_args(int argc, char* argv[]) ck_tile::ArgParser arg_parser; arg_parser.insert("m", "3328", "m dimension") .insert("n", "4096", "n dimension") - .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("x_stride", "-1", "x row_stride, if -1 then equal to n") + .insert("xr_stride", "-1", "x residual row_stride, if -1 then equal to n") + .insert("y_stride", "-1", "y row_stride, if -1 then equal to n") + .insert("yr_stride", "-1", "y residual row_stride, if -1 then equal to n") .insert("e", "1e-5", "epsilon") .insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case") .insert("v", "1", "cpu validation or not") @@ -54,11 +57,20 @@ template bool run(const ck_tile::ArgParser& arg_parser) { - ck_tile::index_t m = arg_parser.get_int("m"); - ck_tile::index_t n = arg_parser.get_int("n"); - ck_tile::index_t stride = arg_parser.get_int("stride"); - if(stride < 0) - stride = n; + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t x_stride = arg_parser.get_int("x_stride"); + if(x_stride < 0) + x_stride = n; + ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride"); + if(xr_stride < 0) + xr_stride = n; + ck_tile::index_t y_stride = arg_parser.get_int("y_stride"); + if(y_stride < 0) + y_stride = n; + ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride"); + if(yr_stride < 0) + yr_stride = n; float epsilon = arg_parser.get_float("e"); std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_o = arg_parser.get_str("prec_o"); @@ -89,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return false; } - assert(stride >= n); + assert(x_stride >= n && xr_stride >= n && y_stride >= n && yr_stride >= n); using 
TypeConfig = LayerNormTypeConfig; @@ -108,15 +120,15 @@ bool run(const ck_tile::ArgParser& arg_parser) using ComputeDataType = typename TypeConfig::ComputeDataType; // host verify - ck_tile::HostTensor x_host({m, n}, {stride, 1}); + ck_tile::HostTensor x_host({m, n}, {x_stride, 1}); ck_tile::HostTensor gamma_host({n}); ck_tile::HostTensor beta_host({n}); - ck_tile::HostTensor x_residual_host({m, n}, {stride, 1}); - ck_tile::HostTensor y_residual_host({m, n}, {stride, 1}); + ck_tile::HostTensor x_residual_host({m, n}, {xr_stride, 1}); + ck_tile::HostTensor y_residual_host({m, n}, {yr_stride, 1}); - ck_tile::HostTensor y_host_ref({m, n}, {stride, 1}); - ck_tile::HostTensor y_host_dev({m, n}, {stride, 1}); + ck_tile::HostTensor y_host_ref({m, n}, {y_stride, 1}); + ck_tile::HostTensor y_host_dev({m, n}, {y_stride, 1}); ck_tile::HostTensor mean_host_ref({m}); ck_tile::HostTensor invStd_host_ref({m}); @@ -162,7 +174,9 @@ bool run(const ck_tile::ArgParser& arg_parser) }(); std::cout << "[" << prec_str << "]" - << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride + << ", yr_stride:" << yr_stride << std::flush; layernorm2d_fwd_traits traits{ prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant}; @@ -182,7 +196,10 @@ bool run(const ck_tile::ArgParser& arg_parser) epsilon, m, n, - stride}; + x_stride, // x row_stride + xr_stride, // x residual row stride + y_stride, // y row stride + yr_stride}; // y residual row stride float ave_time = layernorm2d_fwd( traits, args, ck_tile::stream_config{nullptr, true, kname ? 
1 : 0, warmup, repeat}); @@ -285,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser) y_buf.FromDevice(y_host_dev.data()); - ck_tile::HostTensor y_residual_host_dev({m, n}, {stride, 1}); + ck_tile::HostTensor y_residual_host_dev({m, n}, {yr_stride, 1}); if(fused_add == 1) { y_residual_buf.FromDevice(y_residual_host_dev.data()); @@ -293,7 +310,7 @@ bool run(const ck_tile::ArgParser& arg_parser) auto [rtol, atol] = get_elimit(); - if(stride == n) + if(x_stride == n && y_stride == n && yr_stride == n) { pass = ck_tile::check_err( y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); @@ -310,10 +327,10 @@ bool run(const ck_tile::ArgParser& arg_parser) { for(int i_r = 0; i_r < m; i_r++) { - std::vector y_host_dev_row(y_host_dev.begin() + i_r * stride, - y_host_dev.begin() + i_r * stride + n); - std::vector y_host_ref_row(y_host_ref.begin() + i_r * stride, - y_host_ref.begin() + i_r * stride + n); + std::vector y_host_dev_row(y_host_dev.begin() + i_r * y_stride, + y_host_dev.begin() + i_r * y_stride + n); + std::vector y_host_ref_row(y_host_ref.begin() + i_r * y_stride, + y_host_ref.begin() + i_r * y_stride + n); pass &= ck_tile::check_err(y_host_dev_row, y_host_ref_row, std::string("OUT[") + std::to_string(i_r) + @@ -323,10 +340,10 @@ bool run(const ck_tile::ArgParser& arg_parser) if(fused_add == 1) { std::vector y_residual_host_dev_row( - y_residual_host_dev.begin() + i_r * stride, - y_residual_host_dev.begin() + i_r * stride + n); + y_residual_host_dev.begin() + i_r * yr_stride, + y_residual_host_dev.begin() + i_r * yr_stride + n); std::vector y_residual_host_ref_row( - x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n); + x_host.begin() + i_r * x_stride, x_host.begin() + i_r * x_stride + n); pass &= ck_tile::check_err(y_residual_host_dev_row, y_residual_host_ref_row, std::string("ADD[") + std::to_string(i_r) + diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp 
b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index f5a214ba57..10218e8084 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -28,7 +28,10 @@ struct Layernorm2dFwdHostArgs index_t m; index_t n; - index_t stride; // row_stride + index_t x_stride; // x row_stride + index_t xr_stride; // x residual row stride + index_t y_stride; // y row stride + index_t yr_stride; // y residual row stride }; // TODO: Extract some type to wrapper class @@ -93,7 +96,10 @@ struct Layernorm2dFwd index_t m; index_t n; - index_t stride; // row_stride + index_t x_stride; // x row_stride + index_t xr_stride; // x residual row stride + index_t y_stride; // y row stride + index_t yr_stride; // y residual row stride }; using Hargs = Layernorm2dFwdHostArgs; @@ -112,7 +118,10 @@ struct Layernorm2dFwd hargs.epsilon, hargs.m, hargs.n, - hargs.stride}; + hargs.x_stride, + hargs.xr_stride, + hargs.y_stride, + hargs.yr_stride}; } CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) @@ -182,7 +191,7 @@ struct Layernorm2dFwd const auto tmp_ = make_naive_tensor_view( static_cast(kargs.p_x), make_tuple(kargs.m, kargs.n), - make_tuple(kargs.stride, 1), + make_tuple(kargs.x_stride, 1), number{}, number<1>{}); @@ -201,7 +210,7 @@ struct Layernorm2dFwd const auto tmp_ = make_naive_tensor_view( static_cast(kargs.p_x_residual), make_tuple(kargs.m, kargs.n), - make_tuple(kargs.stride, 1), + make_tuple(kargs.xr_stride, 1), number{}, number<1>{}); @@ -250,7 +259,7 @@ struct Layernorm2dFwd auto tmp_ = make_naive_tensor_view( static_cast(kargs.p_y), make_tuple(kargs.m, kargs.n), - make_tuple(kargs.stride, 1), + make_tuple(kargs.y_stride, 1), number{}, number<1>{}); @@ -266,7 +275,7 @@ struct Layernorm2dFwd auto tmp_ = make_naive_tensor_view( static_cast(kargs.p_y_residual), make_tuple(kargs.m, kargs.n), - make_tuple(kargs.stride, 1), + make_tuple(kargs.yr_stride, 1), number{}, 
number<1>{});