mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
* Support wave32/wave64 in CK_TILE - Part 1 * remove blocksize in kernel launch * fix build error * fix clang format * fix clang format 2 * fix clang format 3 * fix fmha build error * fix fmha build 2 * fix fmha build 3 * fix build error 4 * address review comment * update change log * replace KernelBlockSize with kBlockSize * fix CI fail * fix clang format * address review comment and rebase code. * fix universal test fail --------- Co-authored-by: Lin, Qun <Quentin.Lin+amdeng@amd.com> Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
416 lines
18 KiB
C++
416 lines
18 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
#pragma once
|
|
|
|
#include <sstream>
|
|
#include <gtest/gtest.h>
|
|
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/host.hpp"
|
|
#include "ck_tile/host/kernel_launch.hpp"
|
|
#include "ck_tile/ops/epilogue.hpp"
|
|
#include "ck_tile/ops/gemm.hpp"
|
|
#include "ck_tile/core/numeric/math.hpp"
|
|
|
|
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
|
auto calculate_rtol_atol(const ck_tile::index_t K,
|
|
const ck_tile::index_t kbatch,
|
|
const float max_accumulated_value)
|
|
{
|
|
using ComputeType =
|
|
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
|
// Calculate thresholds
|
|
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
|
ck_tile::integer_divide_ceil(K, kbatch));
|
|
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
|
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
|
// Calculate error due to split_k accumulation
|
|
const auto rtol_split_k =
|
|
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
|
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
|
max_accumulated_value, kbatch);
|
|
// Use higher threshold
|
|
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
|
}
|
|
|
|
/// Selects which ck_tile GEMM pipeline implementation a test instantiates.
/// Each enumerator is mapped to concrete pipeline types by GemmPipelineTypeSelector.
enum class GemmPipelineType
{
    Mem,    ///< ck_tile::GemmPipelineAgBgCrMem
    CompV3, ///< ck_tile::GemmPipelineAgBgCrCompV3
    CompV4  ///< ck_tile::GemmPipelineAgBgCrCompV4
};
|
|
|
|
template <GemmPipelineType PT, typename Problem>
|
|
struct GemmPipelineTypeSelector;
|
|
|
|
template <typename Problem>
|
|
struct GemmPipelineTypeSelector<GemmPipelineType::Mem, Problem>
|
|
{
|
|
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrMem<Problem>;
|
|
using pipeline = ck_tile::GemmPipelineAgBgCrMem<Problem>;
|
|
|
|
static constexpr auto GetName() { return "GemmPipelineAgBgCrMem"; }
|
|
};
|
|
|
|
template <typename Problem>
|
|
struct GemmPipelineTypeSelector<GemmPipelineType::CompV3, Problem>
|
|
{
|
|
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>;
|
|
using pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;
|
|
|
|
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV3"; }
|
|
};
|
|
|
|
template <typename Problem>
|
|
struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
|
|
{
|
|
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<Problem>;
|
|
using pipeline = ck_tile::GemmPipelineAgBgCrCompV4<Problem>;
|
|
|
|
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV4"; }
|
|
};
|
|
|
|
template <typename Tuple, typename Derived>
|
|
class TestCkTileGemmPipeline : public ::testing::Test
|
|
{
|
|
protected:
|
|
using ALayout = std::tuple_element_t<0, Tuple>;
|
|
using BLayout = std::tuple_element_t<1, Tuple>;
|
|
using CLayout = std::tuple_element_t<2, Tuple>;
|
|
using ADataType = std::tuple_element_t<3, Tuple>;
|
|
using BDataType = std::tuple_element_t<4, Tuple>;
|
|
using AccDataType = std::tuple_element_t<5, Tuple>;
|
|
using CDataType = std::tuple_element_t<6, Tuple>;
|
|
static constexpr auto Scheduler = std::tuple_element_t<13, Tuple>::value;
|
|
static constexpr auto PipelineType = std::tuple_element_t<14, Tuple>::value;
|
|
|
|
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>{};
|
|
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>{};
|
|
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>{};
|
|
|
|
static constexpr ck_tile::index_t M_Warp_Tile = std::tuple_element_t<10, Tuple>{};
|
|
static constexpr ck_tile::index_t N_Warp_Tile = std::tuple_element_t<11, Tuple>{};
|
|
static constexpr ck_tile::index_t K_Warp_Tile = std::tuple_element_t<12, Tuple>{};
|
|
|
|
using DsLayout = ck_tile::tuple<>;
|
|
using DsDataType = ck_tile::tuple<>;
|
|
|
|
static constexpr bool Persistent =
|
|
ck_tile::tuple_element_or_default_t<Tuple, 15, std::false_type>::value;
|
|
|
|
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
|
|
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
|
{
|
|
constexpr ck_tile::index_t M_Warp = 2;
|
|
constexpr ck_tile::index_t N_Warp = 2;
|
|
constexpr ck_tile::index_t K_Warp = 1;
|
|
|
|
constexpr bool kPadM = PadM;
|
|
constexpr bool kPadN = PadN;
|
|
constexpr bool kPadK = PadK;
|
|
constexpr bool preshuffle = Preshuffle;
|
|
|
|
constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4) ? true : false;
|
|
|
|
// TODO: For now - but this should also be a test parameter
|
|
constexpr bool TransposeC = false;
|
|
|
|
constexpr int kBlockPerCu = 1;
|
|
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
|
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
|
|
|
// ===============================================
|
|
|
|
using GemmShape =
|
|
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
|
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
|
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
|
using TilePartitioner = ck_tile::
|
|
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
|
|
|
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
|
static constexpr bool StructuredSparsity = false;
|
|
static constexpr bool NumWaveGroup = 1;
|
|
|
|
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
|
kPadN,
|
|
kPadK,
|
|
DoubleSmemBuffer,
|
|
ALayout,
|
|
BLayout,
|
|
CLayout,
|
|
TransposeC,
|
|
StructuredSparsity,
|
|
Persistent,
|
|
NumWaveGroup,
|
|
preshuffle>;
|
|
|
|
using GemmPipelineProblem =
|
|
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
|
|
|
using BaseGemmPipeline =
|
|
typename GemmPipelineTypeSelector<PipelineType, GemmPipelineProblem>::base_pipeline;
|
|
|
|
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
|
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
|
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
|
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
|
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
|
|
|
const auto Run = [&](const auto has_hot_loop_,
|
|
const auto tail_number_,
|
|
const auto memory_operation_) {
|
|
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
|
constexpr auto tail_number_v = tail_number_.value;
|
|
constexpr auto memory_operation = memory_operation_.value;
|
|
|
|
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
|
BDataType,
|
|
AccDataType,
|
|
GemmShape,
|
|
GemmUniversalTraits,
|
|
Scheduler,
|
|
has_hot_loop_v,
|
|
tail_number_v>;
|
|
|
|
using GemmPipeline =
|
|
typename GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
|
|
|
|
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
|
ck_tile::CShuffleEpilogueProblem<ADataType,
|
|
BDataType,
|
|
DsDataType,
|
|
AccDataType,
|
|
CDataType,
|
|
DsLayout,
|
|
CLayout,
|
|
ck_tile::element_wise::PassThrough,
|
|
TilePartitioner::MPerBlock,
|
|
TilePartitioner::NPerBlock,
|
|
M_Warp,
|
|
N_Warp,
|
|
M_Warp_Tile,
|
|
N_Warp_Tile,
|
|
K_Warp_Tile,
|
|
UniversalGemmProblem::TransposeC,
|
|
memory_operation>>;
|
|
|
|
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
|
auto kargs = Kernel::MakeKernelArgs(args);
|
|
|
|
dim3 grids;
|
|
if constexpr(Persistent)
|
|
{
|
|
grids = Kernel::MaxOccupancyGridSize(s);
|
|
}
|
|
else
|
|
{
|
|
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
|
}
|
|
dim3 blocks = Kernel::BlockSize();
|
|
|
|
if(!Kernel::IsSupportedArgument(kargs))
|
|
{
|
|
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
|
}
|
|
|
|
if(s.log_level_ > 0)
|
|
{
|
|
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", "
|
|
<< grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", "
|
|
<< blocks.y << ", " << blocks.z << "}" << std::endl;
|
|
}
|
|
|
|
ck_tile::launch_kernel(
|
|
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
|
};
|
|
|
|
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
|
if(args.k_batch == 1)
|
|
{
|
|
Run(has_hot_loop_,
|
|
tail_number_,
|
|
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
|
ck_tile::memory_operation_enum::set>{});
|
|
}
|
|
else
|
|
{
|
|
Run(has_hot_loop_,
|
|
tail_number_,
|
|
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
|
ck_tile::memory_operation_enum::atomic_add>{});
|
|
}
|
|
};
|
|
|
|
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
|
}
|
|
|
|
template <typename ADataType,
|
|
typename BDataType,
|
|
typename AccDataType,
|
|
ck_tile::index_t M_Warp_Tile,
|
|
ck_tile::index_t N_Warp_Tile,
|
|
ck_tile::index_t K_Warp_Tile>
|
|
bool check_data_type()
|
|
{
|
|
return static_cast<Derived*>(this)
|
|
->template check_data_type_impl<ADataType,
|
|
BDataType,
|
|
AccDataType,
|
|
M_Warp_Tile,
|
|
N_Warp_Tile,
|
|
K_Warp_Tile>();
|
|
}
|
|
|
|
template <typename ADataType,
|
|
typename BDataType,
|
|
typename AccDataType,
|
|
ck_tile::index_t M_Warp_Tile,
|
|
ck_tile::index_t N_Warp_Tile,
|
|
ck_tile::index_t K_Warp_Tile>
|
|
bool check_data_type_impl()
|
|
{
|
|
return true;
|
|
}
|
|
|
|
public:
|
|
std::vector<int> k_batches_;
|
|
|
|
void SetUp() override
|
|
{
|
|
if(!check_data_type<ADataType,
|
|
BDataType,
|
|
AccDataType,
|
|
M_Warp_Tile,
|
|
N_Warp_Tile,
|
|
K_Warp_Tile>())
|
|
{
|
|
GTEST_SKIP() << "Unsupported data type combination for gemm pipeline test.";
|
|
}
|
|
if constexpr(PipelineType == GemmPipelineType::CompV4)
|
|
{
|
|
// Only do k_batch = 1 when pipeline is CompV4
|
|
k_batches_ = {1};
|
|
}
|
|
else
|
|
{
|
|
// Otherwise, use k_batch = 1 and 2
|
|
k_batches_ = {1, 2};
|
|
}
|
|
}
|
|
|
|
template <bool PadM = true, bool PadN = true, bool PadK = true, bool Preshuffle = false>
|
|
void Run(const int M,
|
|
const int N,
|
|
const int K,
|
|
const int StrideA = 0,
|
|
const int StrideB = 0,
|
|
const int StrideC = 0)
|
|
{
|
|
for(auto kb : k_batches_)
|
|
{
|
|
RunSingle<PadM, PadN, PadK, Preshuffle>(M, N, K, StrideA, StrideB, StrideC, kb);
|
|
}
|
|
}
|
|
|
|
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
|
|
void RunSingle(const int M,
|
|
const int N,
|
|
const int K,
|
|
const int StrideA,
|
|
const int StrideB,
|
|
const int StrideC,
|
|
int kbatch = 1)
|
|
{
|
|
using namespace ck_tile::literals;
|
|
|
|
auto f_host_tensor_descriptor = [](std::size_t row,
|
|
std::size_t col,
|
|
std::size_t stride,
|
|
auto layout) {
|
|
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
|
{
|
|
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
|
}
|
|
else
|
|
{
|
|
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
|
|
}
|
|
};
|
|
|
|
auto f_get_default_stride =
|
|
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
|
if(stride == 0)
|
|
{
|
|
// give a chance if stride is zero, return a default packed stride
|
|
if constexpr(std::is_same_v<decltype(layout),
|
|
ck_tile::tensor_layout::gemm::RowMajor>)
|
|
{
|
|
return col;
|
|
}
|
|
else
|
|
{
|
|
return row;
|
|
}
|
|
}
|
|
else
|
|
return stride;
|
|
};
|
|
|
|
ck_tile::index_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{});
|
|
ck_tile::index_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{});
|
|
ck_tile::index_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{});
|
|
|
|
ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
|
|
ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
|
|
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
|
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
|
|
|
|
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, 11939}(a_m_k);
|
|
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, 11940}(b_k_n);
|
|
|
|
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
|
|
|
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
|
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
|
c_m_n_dev_buf.SetZero();
|
|
c_m_n_dev_result.SetZero();
|
|
|
|
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
|
|
b_k_n_dev_buf.GetDeviceBuffer(),
|
|
c_m_n_dev_buf.GetDeviceBuffer(),
|
|
kbatch,
|
|
M,
|
|
N,
|
|
K,
|
|
stride_A,
|
|
stride_B,
|
|
stride_C};
|
|
|
|
invoke_gemm<PadM, PadN, PadK, Preshuffle>(args, ck_tile::stream_config{nullptr, false});
|
|
|
|
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
|
bool pass = true;
|
|
|
|
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
|
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
|
|
c_m_n_host_ref.SetZero();
|
|
|
|
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
|
a_m_k, b_k_n, c_m_n_host_ref);
|
|
|
|
const float max_accumulated_value =
|
|
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
|
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
|
K, kbatch, max_accumulated_value);
|
|
pass = ck_tile::check_err(c_m_n_dev_result,
|
|
c_m_n_host_ref,
|
|
"Error: Incorrect results!",
|
|
rtol_atol.at(ck_tile::number<0>{}),
|
|
rtol_atol.at(ck_tile::number<1>{}));
|
|
EXPECT_TRUE(pass);
|
|
}
|
|
};
|