#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

// NOTE(review): in this copy of the file several template argument lists appear to be
// missing (e.g. `remove_cvref_t;`, `static_cast(...)`, `number{}`, `sequence{}`,
// `template struct ...`, and five identical `template <> struct t2s` specializations).
// As written the code cannot compile; restore the arguments from upstream before use.

namespace ck_tile {

// Device-kernel wrapper around a tile `Pipeline`.
//
// Launch shape (see GridSize/BlockSize): a 1-D grid of ceil(m / Block_M) blocks; the block
// size is Problem::BlockShape::BlockSize. Each block builds windows over the m x n input
// `p_x` (row stride `x_stride`, unit column stride) and the m x n output `p_qx`, both
// anchored at row get_block_id() * Block_M, column 0. The pipeline also receives `n`
// (presumably so it can cover the full row within the block -- confirm in the pipeline).
// Judging by the type names and the "quant_fwd_" kernel-name prefix, the pipeline
// presumably quantizes X into QX using `p_scale` -- TODO confirm.
template struct PerTensorQuant
{
    using Pipeline = remove_cvref_t;
    using Problem = typename Pipeline::Problem;

    using XDataType = remove_cvref_t;
    using ScaleDataType = remove_cvref_t;
    using ComputeDataType = remove_cvref_t;
    using QXDataType = remove_cvref_t;

    // Tile extents and padding flags, taken from the problem definition.
    static constexpr index_t Block_M = Problem::BlockShape::Block_M;
    static constexpr index_t Block_N = Problem::BlockShape::Block_N;
    static constexpr bool kPadM = false; // always no need to pad along M
    static constexpr bool kPadN = Problem::kPadN;

    // NOTE(review): ThreadPerWarp_N, Vector_N and Repeat_N are not referenced by name in
    // the code as shown (Vector_N may have been one of the missing template arguments).
    static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
    static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
    static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;

    // Kernel arguments; the same struct is passed by value to operator() on the device.
    struct Kargs
    {
        const void* p_x;   // input X, viewed as m x n with row stride x_stride
        void* p_scale;     // cast to ScaleDataType* and forwarded to the pipeline unchanged
        void* p_qx;        // output QX, viewed as m x n (presumably written by the pipeline)
        index_t m;         // number of rows; determines the grid size
        index_t n;         // number of columns; also forwarded to the pipeline
        index_t x_stride; // input row_stride
        // NOTE(review): operator() also uses x_stride as the row stride of p_qx; there is
        // no separate output stride, so QX must share X's row pitch -- confirm with callers.
    };

    // Returns `kargs` unchanged (host and device use the same Kargs layout).
    CK_TILE_HOST static constexpr Kargs MakeKargs(const Kargs& kargs) { return kargs; }

    // One block per Block_M rows; the grid is 1-D, so the N extent is covered inside a block.
    CK_TILE_HOST static constexpr auto GridSize(const Kargs& kargs)
    {
        return dim3(integer_divide_ceil(kargs.m, Block_M));
    }

    // Block size as specified by the problem's block shape.
    CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }

    // Maps a data type to the short tag used by GetName().
    // NOTE(review): none of the specializations below names its type in this copy, which
    // would make them redefinitions of one another.
    // clang-format off
    template struct t2s;
    template <> struct t2s { static constexpr const char * name = "fp32"; };
    template <> struct t2s { static constexpr const char * name = "fp16"; };
    template <> struct t2s { static constexpr const char * name = "bf16"; };
    template <> struct t2s { static constexpr const char * name = "fp8"; };
    template <> struct t2s { static constexpr const char * name = "bf8"; };
    // clang-format on

    // in byte
    // Shared-memory requirement, delegated to the pipeline; sizes `smem` in operator().
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }

    // Builds a kernel name of the form
    //   quant_fwd_<type>_<Block_M>x<Block_N>_<WarpPerBlock_M>x<WarpPerBlock_N>_
    //   <Warp_M>x<Warp_N>_<Vector_M>x<Vector_N>_<Pipeline::name>[_pn]
    // where "_pn" is appended when kPadN is set.
    // NOTE(review): only one data type is encoded (presumably XDataType; its template
    // argument is missing here). Kernels differing only in QXDataType/ScaleDataType would
    // therefore get the same name.
    CK_TILE_HOST static std::string GetName()
    {
        // clang-format off
        using S_ = typename Problem::BlockShape;
        // Name suffix (the local's name is a misspelling of "suffix").
        auto surfix = [&] () {
            std::string n;
            if (kPadN) n += "_pn";
            return n;
        }();

        // NOTE(review): identifiers that start with an underscore followed by an uppercase
        // letter (_SS_, _TS_) are reserved in C++; consider renaming these helper macros.
        #define _SS_ std::string
        #define _TS_ std::to_string
        return _SS_("quant_fwd_") + _SS_(t2s::name) + "_" +
             _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" +
             _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
             _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" +
             _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
             _SS_(Pipeline::name) + surfix;
        #undef _SS_
        #undef _TS_
        // clang-format on
    }

    // Device entry point: builds this block's input/output tile windows and runs the
    // pipeline on them.
    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // Row origin of this block's windows.
        const auto iM = get_block_id() * Block_M;

        // Read view over X: lengths (m, n), strides (x_stride, 1), then pad_tensor_view
        // and a tile window anchored at (iM, 0). The padding/window extents and pad flags
        // are presumably Block_M x Block_N and kPadM/kPadN, but those template arguments
        // are missing here -- confirm.
        const auto x_window = [&]() {
            const auto tmp_ = make_naive_tensor_view(
                static_cast(kargs.p_x),
                make_tuple(kargs.m, kargs.n),
                make_tuple(kargs.x_stride, 1),
                number{},
                number<1>{});

            const auto tmp2_ = pad_tensor_view(
                tmp_, make_tuple(number{}, number{}), sequence{});
            return make_tile_window(
                tmp2_, make_tuple(number{}, number{}), {iM, 0});
        }();

        // Same construction over QX. It is non-const, presumably because the pipeline
        // writes through it. Note it reuses kargs.x_stride as the output row stride.
        auto qx_window = [&]() {
            auto tmp_ = make_naive_tensor_view(
                static_cast(kargs.p_qx),
                make_tuple(kargs.m, kargs.n),
                make_tuple(kargs.x_stride, 1),
                number{},
                number<1>{});

            auto tmp2_ = pad_tensor_view(
                tmp_, make_tuple(number{}, number{}), sequence{});
            return make_tile_window(
                tmp2_, make_tuple(number{}, number{}), {iM, 0});
        }();

        // Block-shared scratch buffer of GetSmemSize() bytes, handed to the pipeline.
        // NOTE(review): if Pipeline::GetSmemSize() can be 0 this is a zero-length array,
        // which is ill-formed in standard C++ (a compiler extension at best) -- confirm.
        __shared__ char smem[GetSmemSize()];
        ScaleDataType* scale = static_cast(kargs.p_scale);

        Pipeline{}(x_window, scale, kargs.n, qx_window, smem);
    }
};

} // namespace ck_tile