// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/smoothquant.hpp" #include template struct SmoothquantTypeConfig; template <> struct SmoothquantTypeConfig { using XDataType = ck_tile::half_t; using SmoothScaleDataType = float; using YScaleDataType = float; using QYDataType = ck_tile::int8_t; using ComputeDataType = float; }; template <> struct SmoothquantTypeConfig { using XDataType = ck_tile::bf16_t; using SmoothScaleDataType = float; using YScaleDataType = float; using QYDataType = ck_tile::int8_t; using ComputeDataType = float; }; // runtime args struct smoothquant_args : public ck_tile::SmoothquantHostArgs { }; // this is used to pattern-match internl kernel implementation, not to instantiate kernel template struct smoothquant_traits_ { using DataType = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; using BlockTile = ck_tile::sequence; using Vector = ck_tile::sequence<1, Vector_N_>; using ThreadPerBlock = ck_tile::sequence; using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kTwoPass = kTwoPass_; }; template float smoothquant_(const ck_tile::stream_config& s, smoothquant_args a); // This is the public API, will be generated by script struct smoothquant_traits { std::string data_type; }; float smoothquant(smoothquant_traits, smoothquant_args, const ck_tile::stream_config&);