mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
Tests for CK Tile Flatmm and MOE Smoothquant (#2458)
* CK tile tests for flatmm using example * MOE smoothquant draft tests * fix create_arg default index to zero for MOE smoothquant * revert MOE smoothquant changes * code clean up * Add back MOE smoothquant changes * Add MOE smoothquant cases for different precisions and update cmake * clean up comments * Update flamm cmake * revert change made to moe_smoothquant smoke_test.sh EXE path * remove unecessary comment in MOE smoothquant cmakelist * comment out adding moe_smoothquant subdirectory for now due to bugs with GPU core dump issue on gfx942 and gfx90a * Clean up run_test_case function in MOE smootquant tests * update copyright and licensing on files * Remove flatmm test dir since tests should be done as weighted preshuffle gemm * Add flamm smoke test cases to weighted preshuffle gemm gtests * remove blank line from CMakeLists --------- Co-authored-by: root <root@ctr-ubbsmc16.amd.com> Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
104
test/ck_tile/moe_smoothquant/moe_smoothquant.hpp
Normal file
104
test/ck_tile/moe_smoothquant/moe_smoothquant.hpp
Normal file
@@ -0,0 +1,104 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/smoothquant.hpp"
|
||||
#include <string>
|
||||
|
||||
template <typename InputType, typename OutputType>
|
||||
struct MoeSmoothquantTypeConfig
|
||||
{
|
||||
using XDataType = InputType;
|
||||
using SmoothScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = OutputType;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
struct moe_smoothquant_args : public ck_tile::MoeSmoothquantHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename InputType_,
|
||||
typename OutputType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kTwoPass_>
|
||||
struct moe_smoothquant_traits_
|
||||
{
|
||||
using InputType = ck_tile::remove_cvref_t<InputType_>;
|
||||
using OutputType = ck_tile::remove_cvref_t<OutputType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float moe_smoothquant_(const ck_tile::stream_config& s, moe_smoothquant_args a);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct moe_smoothquant_traits
|
||||
{
|
||||
std::string in_type; // input type
|
||||
std::string out_type; // output type
|
||||
};
|
||||
|
||||
float moe_smoothquant(moe_smoothquant_traits, moe_smoothquant_args, const ck_tile::stream_config&);
|
||||
Reference in New Issue
Block a user