mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] add more stride for layernorm to support un-continuous Tensor (#1650)
* [CK_TILE] add more stride for layernorm to support un-continuous Tensor * align CK coding style * extend strides to layernrom expample * clang-format...
This commit is contained in:
@@ -25,7 +25,10 @@ auto create_args(int argc, char* argv[])
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3328", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to n")
|
||||
.insert("x_stride", "-1", "x row_stride, if -1 then equal to n")
|
||||
.insert("xr_stride", "-1", "x residule row_stride, if -1 then equal to n")
|
||||
.insert("y_stride", "-1", "y row_stride, if -1 then equal to n")
|
||||
.insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n")
|
||||
.insert("e", "1e-5", "epsilon")
|
||||
.insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
@@ -54,11 +57,20 @@ template <typename InDataType,
|
||||
bool SaveMeanVar>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
if(stride < 0)
|
||||
stride = n;
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
|
||||
if(x_stride < 0)
|
||||
x_stride = n;
|
||||
ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride");
|
||||
if(xr_stride < 0)
|
||||
xr_stride = n;
|
||||
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
|
||||
if(y_stride < 0)
|
||||
y_stride = n;
|
||||
ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride");
|
||||
if(yr_stride < 0)
|
||||
yr_stride = n;
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
@@ -89,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return false;
|
||||
}
|
||||
|
||||
assert(stride >= n);
|
||||
assert(x_stride >= n);
|
||||
|
||||
using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType, XScaleDataType, YScaleDataType>;
|
||||
|
||||
@@ -108,15 +120,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using ComputeDataType = typename TypeConfig::ComputeDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
|
||||
ck_tile::HostTensor<GammaDataType> gamma_host({n});
|
||||
ck_tile::HostTensor<BetaDataType> beta_host({n});
|
||||
|
||||
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {xr_stride, 1});
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {yr_stride, 1});
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {y_stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {y_stride, 1});
|
||||
|
||||
ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
|
||||
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});
|
||||
@@ -162,7 +174,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}();
|
||||
|
||||
std::cout << "[" << prec_str << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
|
||||
<< ", yr_stride:" << yr_stride << std::flush;
|
||||
|
||||
layernorm2d_fwd_traits traits{
|
||||
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant};
|
||||
@@ -182,7 +196,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
epsilon,
|
||||
m,
|
||||
n,
|
||||
stride};
|
||||
x_stride, // x row_stride
|
||||
xr_stride, // x residule row stride
|
||||
y_stride, // y row stride
|
||||
yr_stride}; // y residule row stride
|
||||
|
||||
float ave_time = layernorm2d_fwd(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
@@ -285,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
y_buf.FromDevice(y_host_dev.data());
|
||||
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {yr_stride, 1});
|
||||
if(fused_add == 1)
|
||||
{
|
||||
y_residual_buf.FromDevice(y_residual_host_dev.data());
|
||||
@@ -293,7 +310,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
auto [rtol, atol] = get_elimit<InDataType>();
|
||||
|
||||
if(stride == n)
|
||||
if(x_stride == n)
|
||||
{
|
||||
pass = ck_tile::check_err(
|
||||
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
@@ -310,10 +327,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
for(int i_r = 0; i_r < m; i_r++)
|
||||
{
|
||||
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
|
||||
y_host_dev.begin() + i_r * stride + n);
|
||||
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
|
||||
y_host_ref.begin() + i_r * stride + n);
|
||||
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * y_stride,
|
||||
y_host_dev.begin() + i_r * y_stride + n);
|
||||
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * y_stride,
|
||||
y_host_ref.begin() + i_r * y_stride + n);
|
||||
pass &= ck_tile::check_err(y_host_dev_row,
|
||||
y_host_ref_row,
|
||||
std::string("OUT[") + std::to_string(i_r) +
|
||||
@@ -323,10 +340,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(fused_add == 1)
|
||||
{
|
||||
std::vector<YResidualDataType> y_residual_host_dev_row(
|
||||
y_residual_host_dev.begin() + i_r * stride,
|
||||
y_residual_host_dev.begin() + i_r * stride + n);
|
||||
y_residual_host_dev.begin() + i_r * yr_stride,
|
||||
y_residual_host_dev.begin() + i_r * yr_stride + n);
|
||||
std::vector<YResidualDataType> y_residual_host_ref_row(
|
||||
x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n);
|
||||
x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n);
|
||||
pass &= ck_tile::check_err(y_residual_host_dev_row,
|
||||
y_residual_host_ref_row,
|
||||
std::string("ADD[") + std::to_string(i_r) +
|
||||
|
||||
Reference in New Issue
Block a user