// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"

// NOTE(review): this copy of the header looks damaged. Every angle-bracket
// argument list that starts with an identifier seems to have been stripped
// (e.g. `template struct ...`, `remove_cvref_t;`, `type_convert(0)`,
// `numeric::max()`, `sequence, tuple>`). The code tokens below are left exactly
// as found; the missing lists must be restored from a pristine copy before this
// can compile. Comments marked NOTE(review) point at the affected places.

namespace ck_tile {

// Compile-time switches consumed by DynamicQuantEpilogue.
// Each member simply re-exports the same-named value with a trailing underscore.
// NOTE(review): kPadM_, kPadN_, UseSmoothInputScale_, UseRawStore_ and UseMax3_
// are presumably the (bool) template parameters of this struct; the parameter
// list itself is missing here, so any defaults are unknown - confirm.
template struct DynamicQuantEpilogueTraits
{
    static constexpr bool kPadM               = kPadM_;
    static constexpr bool kPadN               = kPadN_;
    static constexpr bool UseSmoothInputScale = UseSmoothInputScale_;
    static constexpr bool UseRawStore         = UseRawStore_;
    static constexpr bool UseMax3             = UseMax3_;
};

// This epilogue just stores out an M*N matrix, row major.
// The Problem struct bundles the type aliases (accumulator, smooth-scale, y-scale
// and output element types, block shape, traits) that DynamicQuantEpilogue reads.
// NOTE(review): the template parameter list and the argument of every
// remove_cvref_t below are missing - restore before use.
template struct DynamicQuantEpilogueProblem
{
    using AccDataType         = remove_cvref_t;
    using SmoothScaleDataType = remove_cvref_t;
    using YScaleDataType      = remove_cvref_t;
    using ODataType           = remove_cvref_t;
    using BlockShape          = remove_cvref_t; // can consume a generic 2d shape
    using Traits              = remove_cvref_t;
};

// Epilogue that derives one scale per row from the row's absolute maximum,
// writes that scale out, divides the tile by it and stores the result.
// An overload additionally multiplies the tile by a per-column smooth scale first.
//
// TODO: we should put descriptor creation function into policy
// NOTE(review): template parameter list missing; `Problem` is presumably built
// from one of those parameters - confirm.
template struct DynamicQuantEpilogue
{
    // NOTE(review): the remove_cvref_t arguments are missing; presumably each
    // alias forwards the same-named member of Problem - confirm.
    using Problem             = remove_cvref_t;
    using AccDataType         = remove_cvref_t;
    using SmoothScaleDataType = remove_cvref_t;
    using YScaleDataType      = remove_cvref_t;
    using ODataType           = remove_cvref_t;
    using BlockShape          = remove_cvref_t;

    // Flags forwarded from Problem::Traits.
    static constexpr bool kPadM       = Problem::Traits::kPadM;
    static constexpr bool kPadN       = Problem::Traits::kPadN;
    static constexpr bool UseRawStore = Problem::Traits::UseRawStore;
    static constexpr bool UseMax3     = Problem::Traits::UseMax3;

    // Returns a value-initialized BlockReduce2d object; Impl() uses it for the
    // first (per-row) reduction.
    // NOTE(review): the arguments of BlockReduce2dProblem and BlockReduce2d are
    // missing; P_ is presumably the argument of BlockReduce2d - confirm.
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
    {
        using P_ = BlockReduce2dProblem;
        return BlockReduce2d{};
    }

    // Returns a value-initialized BlockReduce2dSync object; Impl() applies it to
    // the partial row maxima right after the first reduction.
    // NOTE(review): template arguments missing, as above.
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
    {
        using P_ = BlockReduce2dProblem;
        return BlockReduce2dSync{};
    }

    // Returns a value-initialized BlockReduce2dCrossWarpSync object; Impl() hands
    // it the `smem` pointer and GetSmemSize() queries its size requirement.
    // NOTE(review): template arguments missing, as above.
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
    {
        using P_ = BlockReduce2dProblem;
        return BlockReduce2dCrossWarpSync{};
    }

    // Builds the static tile distribution that the smooth-scale overload of
    // operator() attaches to the smooth-scale window before loading it.
    // NOTE(review): most arguments of tile_distribution_encoding (the leading
    // `sequence`, the first `tuple`, and the first element of the following two
    // tuples) are missing in both branches and cannot be inferred from what is
    // left - restore from a pristine copy.
    CK_TILE_DEVICE static constexpr auto MakeSmoothInputScaleTileDistribution()
    {
        using S = BlockShape;
#if 0
        // don't remove this
        // Note that if we set the encoding purposely like this, it will result in a compile fail
        // TODO: sm_scale create local-scratch to accept arbitrary acc input (with same length)
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence,
                tuple>,
                tuple, sequence<0, 1>>,
                tuple, sequence<2, 2>>,
                sequence<0, 1, 1>,
                sequence<0, 0, 3>>{});
#else
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence,
                tuple>,
                tuple, sequence<0, 1>>,
                tuple, sequence<1, 2>>,
                sequence<1, 1>,
                sequence<0, 3>>{});
#endif
    }

    // Returns whatever the cross-warp reduce object reports from GetSmemSize();
    // presumably the size the `smem` argument of Impl()/operator() must provide.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
        return reduce_crosswarp_sync.GetSmemSize();
    }

    // Shared implementation of both operator() overloads.
    //   o_dram_window_tmp : destination window for the scaled tile
    //   y_scale_window    : destination window for the per-row scale
    //   o_acc_tile        : input tile (copied, never modified)
    //   smem              : scratch pointer forwarded to the cross-warp reduction
    // Declared `auto` with no return statement, so it returns void.
    // NOTE(review): template parameter list missing; ODramWindowTmp, YScaleWindow
    // and OAccTile are presumably its type parameters - confirm.
    template CK_TILE_DEVICE auto Impl(ODramWindowTmp& o_dram_window_tmp,
                                      YScaleWindow& y_scale_window,
                                      const OAccTile& o_acc_tile,
                                      void* smem)
    {
        auto reduce                = GetBlockReduce2d();
        auto reduce_sync           = GetBlockReduce2dSync();
        auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();

        // Work on a private copy; the input tile is const.
        auto o_acc_tmp = o_acc_tile;

        // Reduction step: running maximum of absolute values.
        const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };

        // First reduction over o_acc_tmp, starting from a converted 0.
        auto row_absmax = [&]() {
            // Length of dimension 1 of the tile's ys-to-d descriptor.
            constexpr auto y_size_per_row =
                OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(
                    number<1>{});
            // NOTE(review): the arguments of std::is_same_v are missing; given the
            // `float rtn` in the asm path this presumably checks that the
            // accumulator type is float - confirm.
            if constexpr(UseMax3 && std::is_same_v && y_size_per_row % 2 == 0)
            {
                // fast max3+abs implementation: one v_max3_f32 folds two values into
                // the accumulator, with abs() applied to both value operands.
                const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
                    float rtn;
                    asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
                                 : "=v"(rtn)
                                 : "v"(acc_), "v"(v_0_), "v"(v_1_));
                    return rtn;
                };
                // NOTE(review): sequence<1, 2> presumably tells reduce() to feed two
                // elements per call (hence the even-length requirement) - confirm.
                // type_convert's target type argument is missing.
                return reduce(o_acc_tmp, type_convert(0), f_max3, sequence<1, 2>{});
            }
            else
            {
                // NOTE(review): type_convert's target type argument is missing.
                return reduce(o_acc_tmp, type_convert(0), f_absmax);
            }
        }();

        // Combine the partial maxima; the second call additionally uses smem.
        reduce_sync(row_absmax, f_absmax);
        reduce_crosswarp_sync(row_absmax, smem, f_absmax);

        // here y_scale is Acc type, need to convert to YScale type later
        // y_scale = row_absmax / max(), element by element.
        // NOTE(review): the arguments of type_convert and numeric are missing;
        // presumably the maximum of the output data type converted to the
        // accumulator type - confirm.
        auto y_scale = tile_elementwise_in(
            [&](const auto& v_) { return v_ / type_convert(numeric::max()); }, row_absmax);

        // NOTE(review): cast_tile's target type argument is missing (see comment above).
        store_tile(y_scale_window, cast_tile(y_scale));

        // Divide every element by the scale selected with the element's index 0.
        // NOTE(review): no guard against y_scale == 0 is visible; a row whose
        // absolute maximum is 0 would divide 0 by 0 here - confirm this is
        // acceptable for callers.
        sweep_tile(o_acc_tmp, [&](auto idx) {
            constexpr auto row_id = make_tuple(idx[number<0>{}]);
            o_acc_tmp(idx)        = o_acc_tmp[idx] / y_scale(row_id);
        });

        // TODO: this is ugly
        // Raw store followed by a buffer store fence when raw stores are enabled
        // and either dimension is padded; ordinary store_tile otherwise.
        // NOTE(review): cast_tile's target type argument is missing in both calls.
        if constexpr(UseRawStore && (kPadM || kPadN))
        {
            store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tmp));
            buffer_store_fence();
        }
        else
        {
            store_tile(o_dram_window_tmp, cast_tile(o_acc_tmp));
        }
    }

    // TODO: this function assumes the store-out vector size is the same as the
    // OAccTile last dimension size; how do we fix this?

    // Smooth Dynamic Quant
    // Loads the smooth scale through sm_scale_window_ (re-windowed with
    // MakeSmoothInputScaleTileDistribution()), multiplies a copy of o_acc_tile by
    // the scale selected with each element's index 1, then runs Impl() on the
    // result. Returns void.
    // NOTE(review): template parameter list missing; ODramWindowTmp,
    // SmoothScaleWindow, YScaleWindow and OAccTile are presumably its type
    // parameters - confirm.
    template CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
                                            const SmoothScaleWindow& sm_scale_window_,
                                            YScaleWindow& y_scale_window,
                                            const OAccTile& o_acc_tile,
                                            void* smem)
    {
        const auto sm_scale_window =
            make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution());

        auto sm_scale = load_tile(sm_scale_window);

        // Private copy so the const input tile can be scaled in place.
        auto o_acc_tmp = o_acc_tile;

        sweep_tile(o_acc_tmp, [&](auto idx) {
            constexpr auto j_idx = make_tuple(idx[number<1>{}]);
            // NOTE(review): type_convert's target type argument is missing.
            const auto xs_ = type_convert(sm_scale[j_idx]);
            o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
        });

        Impl(o_dram_window_tmp, y_scale_window, o_acc_tmp, smem);
    }

    // Dynamic Quant
    // Same as above without the smooth-scale multiplication: forwards straight
    // to Impl(). Returns void.
    // NOTE(review): template parameter list missing, as above.
    template CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
                                            YScaleWindow& y_scale_window,
                                            const OAccTile& o_acc_tile,
                                            void* smem)
    {
        Impl(o_dram_window_tmp, y_scale_window, o_acc_tile, smem);
    }
};
} // namespace ck_tile