Files
composable_kernel/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp
Satyanvesh Dittakavi 4c57157d50 Do not use warpSize as compile time constant as it is removed (#2320)
* Do not use warpSize as compile time constant as it is removed

* Update tile_image_to_column_shape.hpp

update warpSize usage.

* clean-up all use of warpSize, make sure code builds

* fix

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
2025-06-17 11:54:30 -07:00

105 lines
3.6 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/smoothquant.hpp"
#include <string>
// Bundles the element types of one MoE smooth-quant configuration.
//
// InputType is re-exported as XDataType and OutputType as QYDataType; the
// smooth-scale, y-scale and compute types are float for every instantiation.
template <typename InputType, typename OutputType>
struct MoeSmoothquantTypeConfig
{
    // types chosen by the caller
    using XDataType  = InputType;
    using QYDataType = OutputType;

    // types fixed to float, independent of the template arguments
    using ComputeDataType     = float;
    using SmoothScaleDataType = float;
    using YScaleDataType      = float;
};
// Runtime arguments of the example. The struct adds no members of its own: it
// is ck_tile::MoeSmoothquantHostArgs under a local name (a struct derives
// publicly by default, so the base stays fully accessible to callers).
struct moe_smoothquant_args : ck_tile::MoeSmoothquantHostArgs
{
};
// Compile-time description of one kernel configuration. This is used to
// pattern-match the internal kernel implementation, not to instantiate the
// kernel.
//
// NOTE(review): WarpSize is not declared anywhere in this header; it is
// presumably made visible by one of the ck_tile includes -- confirm.
template <typename InputType_,
          typename OutputType_,
          ck_tile::index_t Repeat_M_,         // each thread repeat along M
          ck_tile::index_t Repeat_N_,         // each thread repeat along N
          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
          ck_tile::index_t Vector_N_,         // vector size along N
          bool kPadN_,
          bool kTwoPass_>
struct moe_smoothquant_traits_
{
    // element types with cv-qualifiers and references stripped
    using InputType  = ck_tile::remove_cvref_t<InputType_>;
    using OutputType = ck_tile::remove_cvref_t<OutputType_>;

    // template arguments re-exported unchanged
    static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
    static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
    static constexpr bool kPadN                = kPadN_;
    static constexpr bool kTwoPass             = kTwoPass_;

    // true when the threads along N number at most WarpSize
    static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;

    // the block's thread count must be a whole multiple of WarpSize
    static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
    static constexpr ck_tile::index_t total_warps =
        (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;

    // num of warps along m
    // NOTE(review): when the divisibility asserts hold (and WarpSize is a
    // positive integer) both branches simplify to ThreadPerBlock_M_, which
    // makes Warp_M below equal to 1 -- confirm this is the intent.
    static constexpr ck_tile::index_t BlockWarps_M = []() {
        if constexpr(!is_warp_per_row)
        {
            // the divisibility this branch relies on (ThreadPerBlock_N_ % WarpSize == 0) is
            // asserted in the BlockWarps_N initializer; the check below is disabled
            // static_assert(WarpSize % ThreadPerBlock_M_ == 0);
            return total_warps / (ThreadPerBlock_N_ / WarpSize);
        }
        else
        {
            // WarpSize / ThreadPerBlock_N_ must be exact for the product below
            static_assert(WarpSize % ThreadPerBlock_N_ == 0);
            return total_warps * (WarpSize / ThreadPerBlock_N_);
        }
    }();

    // num of warps along n
    static constexpr ck_tile::index_t BlockWarps_N = []() {
        if constexpr(!is_warp_per_row)
        {
            static_assert(ThreadPerBlock_N_ % WarpSize == 0);
            return ThreadPerBlock_N_ / WarpSize;
        }
        else
        {
            static_assert(WarpSize % ThreadPerBlock_N_ == 0);
            return 1;
        }
    }();

    // block tile extents: repeats * threads along M, repeats * threads * vector along N
    static constexpr ck_tile::index_t Block_M = Repeat_M * ThreadPerBlock_M_;
    static constexpr ck_tile::index_t Block_N = Repeat_N * ThreadPerBlock_N_ * Vector_N_;

    // warp tile extents: threads per block divided by the warps along that dimension
    // (times the vector size along N)
    static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
    static constexpr ck_tile::index_t Warp_N = (ThreadPerBlock_N_ / BlockWarps_N) * Vector_N_;

    // the four sequences handed to Generic2dBlockShape
    using BlockTile  = ck_tile::sequence<Block_M, Block_N>;
    using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
    using WarpTile   = ck_tile::sequence<Warp_M, Warp_N>;
    using Vector     = ck_tile::sequence<1, Vector_N_>;

    using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
};
// Per-configuration entry point; only declared here, the definition is not
// part of this header.
//   Traits_ : presumably a moe_smoothquant_traits_<...> specialization -- TODO confirm
//   s       : stream configuration, taken by const reference
//   a       : runtime arguments, taken by value
// Returns a float whose meaning is not visible from this declaration
// (NOTE(review): looks like a kernel timing value -- confirm against the definition).
template <typename Traits_>
float moe_smoothquant_(const ck_tile::stream_config& s, moe_smoothquant_args a);
// This is the public API, will be generated by script.
// Plain aggregate of two strings naming the input and output types; how the
// strings are interpreted is up to moe_smoothquant(), which is only declared
// in this header.
struct moe_smoothquant_traits
{
    std::string in_type;  // input type
    std::string out_type; // output type
};
// Public entry point; only declared here, the definition is not part of this
// header. Takes the string-based traits and the runtime arguments by value and
// the stream configuration by const reference.
// NOTE(review): presumably dispatches on the traits strings to a matching
// moe_smoothquant_<Traits_> instance and forwards its float result -- confirm
// against the generated definition.
float moe_smoothquant(moe_smoothquant_traits, moe_smoothquant_args, const ck_tile::stream_config&);