mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Ck tile/smoothquant out stride (#1742)
* add ck_tile/smoothquant out stride parameter * Remove the default stride value --------- Co-authored-by: so <a.com>
This commit is contained in:
@@ -19,7 +19,8 @@ struct SmoothquantHostArgs
|
||||
|
||||
index_t m;
|
||||
index_t n;
|
||||
index_t stride; // row_stride
|
||||
index_t x_stride; // input row_stride
|
||||
index_t y_stride; // output row_stride
|
||||
};
|
||||
|
||||
// TODO: Extract some type to wrapper class
|
||||
@@ -58,14 +59,21 @@ struct Smoothquant
|
||||
|
||||
index_t m;
|
||||
index_t n;
|
||||
index_t stride; // row_stride
|
||||
index_t x_stride; // input row_stride
|
||||
index_t y_stride; // out row_stride
|
||||
};
|
||||
using Hargs = SmoothquantHostArgs;
|
||||
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
|
||||
{
|
||||
return Kargs{
|
||||
hargs.p_x, hargs.p_xscale, hargs.p_yscale, hargs.p_qy, hargs.m, hargs.n, hargs.stride};
|
||||
return Kargs{hargs.p_x,
|
||||
hargs.p_xscale,
|
||||
hargs.p_yscale,
|
||||
hargs.p_qy,
|
||||
hargs.m,
|
||||
hargs.n,
|
||||
hargs.x_stride,
|
||||
hargs.y_stride};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
|
||||
@@ -116,7 +124,7 @@ struct Smoothquant
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XDataType*>(kargs.p_x),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
make_tuple(kargs.x_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -157,7 +165,7 @@ struct Smoothquant
|
||||
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<QYDataType*>(kargs.p_qy),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
make_tuple(kargs.y_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user