mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[Ck_tile] smoothquant (#1617)
* fix compile error * fix typo of padding * Add smoothquant op * Add smoothquant instance library * refine type * add test script * Re-generate smoothquant.hpp * Always use 'current year' in copyright * use Generic2dBlockShape instead * Add vector = 8 instance back * Find exe path automatically * Simplify the api condition * Remove debugging code * update year * Add blank line between function declaration * explicitly cast return value to dim3 * refine return value * Fix default warmup and repeat value * Add comment * refactor sommthquant cmake * Add README * Fix typo --------- Co-authored-by: Po Yen, Chen <PoYen.Chen@amd.com>
This commit is contained in:
176
include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
Normal file
176
include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
Normal file
@@ -0,0 +1,176 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// host side args
|
||||
struct SmoothquantHostArgs
|
||||
{
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_xscale; // [1, n], input, columnwise scale, fp32
|
||||
|
||||
void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_xscale)
|
||||
void* p_qy; // [m, n], output, p_x * p_xscale / p_yscale
|
||||
|
||||
index_t m;
|
||||
index_t n;
|
||||
index_t stride; // row_stride
|
||||
};
|
||||
|
||||
// TODO: Extract some type to wrapper class
|
||||
template <typename Pipeline_>
|
||||
struct Smoothquant
|
||||
{
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = typename Pipeline::Problem;
|
||||
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
|
||||
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
static constexpr bool kPadM = false; // always no need to pad along M
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::kTwoPass;
|
||||
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_x;
|
||||
const void* p_xscale;
|
||||
|
||||
void* p_yscale;
|
||||
void* p_qy;
|
||||
|
||||
index_t m;
|
||||
index_t n;
|
||||
index_t stride; // row_stride
|
||||
};
|
||||
using Hargs = SmoothquantHostArgs;
|
||||
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
|
||||
{
|
||||
return Kargs{
|
||||
hargs.p_x, hargs.p_xscale, hargs.p_yscale, hargs.p_qy, hargs.m, hargs.n, hargs.stride};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
|
||||
{
|
||||
return dim3(integer_divide_ceil(hargs.m, Block_M));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
|
||||
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
|
||||
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
// clang-format on
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
|
||||
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
using S_ = typename Problem::BlockShape;
|
||||
auto surfix = [&] () {
|
||||
std::string n;
|
||||
if (kPadN) n += "_pn";
|
||||
if (kTwoPass) n += "_2p";
|
||||
return n; }();
|
||||
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
return _SS_("smoothquant_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
|
||||
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
|
||||
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
|
||||
_SS_(Pipeline::name) + surfix;
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
const auto iM = get_block_id() * Block_M;
|
||||
|
||||
const auto x_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XDataType*>(kargs.p_x),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
const auto tmp2_ = pad_tensor_view(
|
||||
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}();
|
||||
|
||||
const auto xscale_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XScaleDataType*>(kargs.p_xscale),
|
||||
make_tuple(kargs.n),
|
||||
make_tuple(1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
const auto tmp2_ =
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadN>{});
|
||||
|
||||
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
|
||||
}();
|
||||
|
||||
auto yscale_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<YScaleDataType*>(kargs.p_yscale),
|
||||
make_tuple(kargs.m),
|
||||
make_tuple(1),
|
||||
number<1>{});
|
||||
|
||||
const auto tmp2_ =
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
|
||||
|
||||
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
|
||||
}();
|
||||
|
||||
auto qy_window = [&]() {
|
||||
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<QYDataType*>(kargs.p_qy),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
auto tmp2_ = pad_tensor_view(
|
||||
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}();
|
||||
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
Pipeline{}(x_window, xscale_window, yscale_window, qy_window, kargs.n, smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,95 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct SmoothquantPipelineDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
|
||||
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>,
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXScaleBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
|
||||
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2d<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2dSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2dCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
if constexpr(Problem::kNeedCrossWarpSync)
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile =
|
||||
decltype(make_static_distributed_tensor<typename Problem::XDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
|
||||
|
||||
return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1; // zero size arrays are an extension
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,94 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = SmoothquantPipelineDefaultPolicy>
|
||||
struct SmoothquantPipelineOnePass
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using XScaleDataType = ck_tile::remove_cvref_t<typename Problem::XScaleDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
return "bpr_op"; // block per row
|
||||
else
|
||||
return "wpr_op"; // warp per row
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow, typename XScaleWindow, typename QYWindow, typename YScaleWindow>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XScaleWindow& xscale_window_,
|
||||
YScaleWindow& yscale_window,
|
||||
QYWindow& qy_window,
|
||||
ck_tile::index_t,
|
||||
void* smem) const
|
||||
{
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto xscale_window = make_tile_window(
|
||||
xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>());
|
||||
|
||||
auto reduce_absmax_func = ReduceOp::AbsMax{};
|
||||
auto reduce_max_func = ReduceOp::Max{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
const auto x = load_tile(x_window);
|
||||
const auto xscale = load_tile(xscale_window);
|
||||
auto y = tile_elementwise_in(
|
||||
[&](const auto& a, const auto& b) {
|
||||
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
|
||||
},
|
||||
x,
|
||||
xscale);
|
||||
|
||||
// compute absmax, cross-lane->cross-warp
|
||||
auto absmax = block_reduce2d(
|
||||
y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
|
||||
block_reduce2d_sync(absmax, reduce_max_func);
|
||||
block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
|
||||
|
||||
// ex: yscale = absmax / 127 if int8
|
||||
auto yscale = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return v_ / type_convert<ComputeDataType>(numeric<QYDataType>::max());
|
||||
},
|
||||
absmax);
|
||||
store_tile(yscale_window, cast_tile<YScaleDataType>(yscale));
|
||||
|
||||
// quantize y to qy
|
||||
auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
|
||||
sweep_tile(qy, [&](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
auto qy_ = y[idx] / yscale[i_idx];
|
||||
qy(idx) = saturates<QYDataType>{}(qy_);
|
||||
});
|
||||
store_tile(qy_window, qy);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Y = X * XScale, QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
|
||||
template <typename XDataType_,
|
||||
typename XScaleDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename QYDataType_,
|
||||
typename BlockShape_,
|
||||
bool kPadN_,
|
||||
bool kTwoPass_>
|
||||
struct SmoothquantPipelineProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using XScaleDataType = remove_cvref_t<XScaleDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using QYDataType = remove_cvref_t<QYDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,132 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = SmoothquantPipelineDefaultPolicy>
|
||||
struct SmoothquantPipelineTwoPass
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using XScaleDataType = ck_tile::remove_cvref_t<typename Problem::XScaleDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
return "bpr_tp"; // block per row
|
||||
else
|
||||
return "wpr_tp"; // warp per row
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow, typename XScaleWindow, typename QYWindow, typename YScaleWindow>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XScaleWindow& xscale_window_,
|
||||
YScaleWindow& yscale_window,
|
||||
QYWindow& qy_window,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem) const
|
||||
{
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto xscale_window = make_tile_window(
|
||||
xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>());
|
||||
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
|
||||
|
||||
auto reduce_absmax_func = ReduceOp::AbsMax{};
|
||||
auto reduce_max_func = ReduceOp::Max{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
|
||||
auto absmax = block_reduce2d.template MakeYBlockTile<XTensorType>();
|
||||
set_tile(absmax, reduce_absmax_func.GetIdentityValue<ComputeDataType>());
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
const auto xscale = load_tile(xscale_window);
|
||||
const auto y = tile_elementwise_in(
|
||||
[&](const auto& a, const auto& b) {
|
||||
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
|
||||
},
|
||||
x,
|
||||
xscale);
|
||||
|
||||
block_reduce2d(y, absmax, reduce_absmax_func);
|
||||
|
||||
move_tile_window(x_window, {0, Block_N});
|
||||
move_tile_window(xscale_window, {Block_N});
|
||||
}
|
||||
|
||||
// compute absmax, cross-lane->cross-warp
|
||||
block_reduce2d_sync(absmax, reduce_max_func);
|
||||
block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
|
||||
|
||||
// ex: yscale = absmax / 127 if int8
|
||||
auto yscale = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return v_ / type_convert<ComputeDataType>(numeric<QYDataType>::max());
|
||||
},
|
||||
absmax);
|
||||
store_tile(yscale_window, cast_tile<YScaleDataType>(yscale));
|
||||
|
||||
// reverse read x to reuse cache
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(xscale_window, {-Block_N});
|
||||
move_tile_window(qy_window, {0, stride_to_right_most_window});
|
||||
|
||||
// recompute y and quantize y to qy
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
const auto xscale = load_tile(xscale_window);
|
||||
const auto y = tile_elementwise_in(
|
||||
[&](const auto& a, const auto& b) {
|
||||
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
|
||||
},
|
||||
x,
|
||||
xscale);
|
||||
|
||||
auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
|
||||
sweep_tile(qy, [&](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
auto qy_ = y[idx] / yscale[i_idx];
|
||||
qy(idx) = saturates<QYDataType>{}(qy_);
|
||||
});
|
||||
store_tile(qy_window, qy);
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(xscale_window, {0, -Block_N});
|
||||
move_tile_window(qy_window, {0, -Block_N});
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user