// SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/smoothquant.hpp" #include template struct MoeSmoothquantTypeConfig { using XDataType = InputType; using SmoothScaleDataType = float; using YScaleDataType = float; using QYDataType = OutputType; using ComputeDataType = float; }; // runtime args struct moe_smoothquant_args : public ck_tile::MoeSmoothquantHostArgs { }; // this is used to pattern-match internl kernel implementation, not to instantiate kernel template struct moe_smoothquant_traits_ { using InputType = ck_tile::remove_cvref_t; using OutputType = ck_tile::remove_cvref_t; static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); static constexpr ck_tile::index_t total_warps = (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { static_assert(warpSize % ThreadPerBlock_N_ == 0); return total_warps * (warpSize / ThreadPerBlock_N_); } else { // static_assert(warpSize % ThreadPerBlock_M_ == 0); return total_warps / (ThreadPerBlock_N_ / warpSize); } }(); // num of warps along n static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { static_assert(warpSize % ThreadPerBlock_N_ == 0); return 1; } else { static_assert(ThreadPerBlock_N_ % warpSize == 0); return ThreadPerBlock_N_ / warpSize; } }(); static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; using BlockTile = ck_tile::sequence; using BlockWarps = ck_tile::sequence; using WarpTile = ck_tile::sequence; using Vector = ck_tile::sequence<1, Vector_N_>; using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kTwoPass = kTwoPass_; }; template float moe_smoothquant_(const ck_tile::stream_config& s, moe_smoothquant_args a); // This is the public API, will be generated by script struct moe_smoothquant_traits { std::string in_type; // input type std::string out_type; // output type }; float moe_smoothquant(moe_smoothquant_traits, moe_smoothquant_args, const ck_tile::stream_config&);