// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/smoothquant.hpp" #include template struct MoeSmoothquantTypeConfig { using XDataType = InputType; using SmoothScaleDataType = float; using YScaleDataType = float; using QYDataType = OutputType; using ComputeDataType = float; }; // runtime args struct moe_smoothquant_args : public ck_tile::MoeSmoothquantHostArgs { }; // this is used to pattern-match internl kernel implementation, not to instantiate kernel template struct moe_smoothquant_traits_ { using InputType = ck_tile::remove_cvref_t; using OutputType = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; using BlockTile = ck_tile::sequence; using Vector = ck_tile::sequence<1, Vector_N_>; using ThreadPerBlock = ck_tile::sequence; using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kTwoPass = kTwoPass_; }; template float moe_smoothquant_(const ck_tile::stream_config& s, moe_smoothquant_args a); // This is the public API, will be generated by script struct moe_smoothquant_traits { std::string in_type; // input type std::string out_type; // output type }; float moe_smoothquant(moe_smoothquant_traits, moe_smoothquant_args, const ck_tile::stream_config&);