mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] add more stride for layernorm to support un-continuous Tensor (#1650)
* [CK_TILE] add more stride for layernorm to support un-continuous Tensor * align CK coding style * extend strides to layernrom expample * clang-format...
This commit is contained in:
@@ -28,7 +28,10 @@ struct Layernorm2dFwdHostArgs
|
||||
|
||||
index_t m;
|
||||
index_t n;
|
||||
index_t stride; // row_stride
|
||||
index_t x_stride; // x row_stride
|
||||
index_t xr_stride; // x residule row stride
|
||||
index_t y_stride; // y row stride
|
||||
index_t yr_stride; // y residule row stride
|
||||
};
|
||||
|
||||
// TODO: Extract some type to wrapper class
|
||||
@@ -93,7 +96,10 @@ struct Layernorm2dFwd
|
||||
|
||||
index_t m;
|
||||
index_t n;
|
||||
index_t stride; // row_stride
|
||||
index_t x_stride; // x row_stride
|
||||
index_t xr_stride; // x residule row stride
|
||||
index_t y_stride; // y row stride
|
||||
index_t yr_stride; // y residule row stride
|
||||
};
|
||||
using Hargs = Layernorm2dFwdHostArgs;
|
||||
|
||||
@@ -112,7 +118,10 @@ struct Layernorm2dFwd
|
||||
hargs.epsilon,
|
||||
hargs.m,
|
||||
hargs.n,
|
||||
hargs.stride};
|
||||
hargs.x_stride,
|
||||
hargs.xr_stride,
|
||||
hargs.y_stride,
|
||||
hargs.yr_stride};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
|
||||
@@ -182,7 +191,7 @@ struct Layernorm2dFwd
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XDataType*>(kargs.p_x),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
make_tuple(kargs.x_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -201,7 +210,7 @@ struct Layernorm2dFwd
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XResidualDataType*>(kargs.p_x_residual),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
make_tuple(kargs.xr_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -250,7 +259,7 @@ struct Layernorm2dFwd
|
||||
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<YDataType*>(kargs.p_y),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
make_tuple(kargs.y_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -266,7 +275,7 @@ struct Layernorm2dFwd
|
||||
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<YResidualDataType*>(kargs.p_y_residual),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
make_tuple(kargs.yr_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user