mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
Reorganize project folders (#6)
This commit is contained in:
5
test/ck_tile/CMakeLists.txt
Normal file
5
test/ck_tile/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
add_subdirectory(image_to_column)
|
||||
add_subdirectory(gemm)
|
||||
add_subdirectory(batched_gemm)
|
||||
add_subdirectory(grouped_gemm)
|
||||
add_subdirectory(data_type)
|
||||
4
test/ck_tile/batched_gemm/CMakeLists.txt
Normal file
4
test/ck_tile/batched_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_ck_tile_batched_gemm test_batched_gemm.cpp)
|
||||
endif()
|
||||
29
test/ck_tile/batched_gemm/test_batched_gemm.cpp
Normal file
29
test/ck_tile/batched_gemm/test_batched_gemm.cpp
Normal file
@@ -0,0 +1,29 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_batched_gemm_util.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
|
||||
// std::tuple< Row, Row, Row, F16, F16, F32, F16>,
|
||||
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
|
||||
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileBatchedGemm, KernelTypes);
|
||||
|
||||
#include "test_batched_gemm_ut_cases.inc"
|
||||
9
test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc
Normal file
9
test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc
Normal file
@@ -0,0 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestCkTileBatchedGemm, Basic)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
281
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
Normal file
281
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
Normal file
@@ -0,0 +1,281 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileBatchedGemm : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "For compute pipeline tail number should always be Full, but have \""
|
||||
<< tail_num << "\" which is not supported! PrefetchStages: "
|
||||
<< BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Num K loop must be larger than number of prefetech stages."
|
||||
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
void Run(const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
int StrideA = 512,
|
||||
int StrideB = 512,
|
||||
int StrideC = 256,
|
||||
const int BatchStrideA = 131072,
|
||||
const int BatchStrideB = 131072,
|
||||
const int BatchStrideC = 65536,
|
||||
const int BatchCount = 8)
|
||||
{
|
||||
using namespace ck_tile::literals;
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
|
||||
std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
std::size_t batch_stride,
|
||||
auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
|
||||
{batch_stride, stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
|
||||
{batch_stride, 1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
// give a chance if stride is zero, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout),
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::BatchedGemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = 1;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.stride_A = StrideA;
|
||||
args.stride_B = StrideB;
|
||||
args.stride_C = StrideC;
|
||||
args.batch_stride_A = BatchStrideA;
|
||||
args.batch_stride_B = BatchStrideB;
|
||||
args.batch_stride_C = BatchStrideC;
|
||||
args.batch_count = BatchCount;
|
||||
|
||||
invoke_batched_gemm<ALayout, BLayout, CLayout>(args,
|
||||
ck_tile::stream_config{nullptr, false});
|
||||
|
||||
std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << StrideA << " StrideB =" << StrideB << " StrideC =" << StrideC
|
||||
<< " BatchStrideA =" << BatchStrideA << " BatchStrideB =" << BatchStrideB
|
||||
<< " BatchStrideC =" << BatchStrideC << " BatchCount =" << BatchCount
|
||||
<< std::endl;
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
const auto b_n_k = b_k_n.transpose({0, 2, 1});
|
||||
ck_tile::reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_n_k, c_m_n_host_ref);
|
||||
|
||||
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
4
test/ck_tile/data_type/CMakeLists.txt
Normal file
4
test/ck_tile/data_type/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp)
|
||||
endif()
|
||||
65
test/ck_tile/data_type/test_pk_int4.cpp
Normal file
65
test/ck_tile/data_type/test_pk_int4.cpp
Normal file
@@ -0,0 +1,65 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
using ck_tile::bf16_t;
|
||||
using ck_tile::bf16x2_t;
|
||||
using ck_tile::fp16x2_t;
|
||||
using ck_tile::fp32x2_t;
|
||||
using ck_tile::half_t;
|
||||
using ck_tile::pk_int4_t;
|
||||
|
||||
TEST(PackedInt4, ConvertToFloat)
|
||||
{
|
||||
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
|
||||
constexpr float first_input_val = 7.f;
|
||||
constexpr float second_input_val = -1.f;
|
||||
#else
|
||||
constexpr float first_input_val = -1.f;
|
||||
constexpr float second_input_val = 7.f;
|
||||
#endif
|
||||
uint8_t data = 0b11110111; // {-1, 7}
|
||||
pk_int4_t in = ck_tile::bit_cast<int8_t>(data);
|
||||
fp32x2_t out = ck_tile::pk_int4_t_to_fp32x2_t(in);
|
||||
|
||||
EXPECT_EQ(out.x, first_input_val);
|
||||
EXPECT_EQ(out.y, second_input_val);
|
||||
}
|
||||
|
||||
TEST(PackedInt4, ConvertToHalf)
|
||||
{
|
||||
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
|
||||
const half_t first_input_val = ck_tile::type_convert<half_t>(7.f);
|
||||
const half_t second_input_val = ck_tile::type_convert<half_t>(-1.f);
|
||||
#else
|
||||
const half_t first_input_val = ck_tile::type_convert<half_t>(-1.f);
|
||||
const half_t second_input_val = ck_tile::type_convert<half_t>(7.f);
|
||||
#endif
|
||||
uint8_t data = 0b11110111; // {-1, 7}
|
||||
pk_int4_t in = ck_tile::bit_cast<int8_t>(data);
|
||||
fp16x2_t out = ck_tile::pk_int4_t_to_halfx2_t(in);
|
||||
|
||||
EXPECT_EQ(out.x, first_input_val);
|
||||
EXPECT_EQ(out.y, second_input_val);
|
||||
}
|
||||
|
||||
TEST(PackedInt4, ConvertToBHalf)
|
||||
{
|
||||
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
|
||||
const bf16_t first_input_val = ck_tile::type_convert<bf16_t>(7.f);
|
||||
const bf16_t second_input_val = ck_tile::type_convert<bf16_t>(-1.f);
|
||||
#else
|
||||
const bf16_t first_input_val = ck_tile::type_convert<bf16_t>(-1.f);
|
||||
const bf16_t second_input_val = ck_tile::type_convert<bf16_t>(7.f);
|
||||
#endif
|
||||
uint8_t data = 0b11110111; // {-1, 7}
|
||||
pk_int4_t in = ck_tile::bit_cast<int8_t>(data);
|
||||
bf16x2_t out = ck_tile::pk_int4_t_to_bfloat16x2_t(in);
|
||||
|
||||
EXPECT_EQ(out.x, first_input_val);
|
||||
EXPECT_EQ(out.y, second_input_val);
|
||||
}
|
||||
28
test/ck_tile/gemm/CMakeLists.txt
Normal file
28
test/ck_tile/gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,28 @@
|
||||
# Currently ck_tile is only built on gfx94/gfx95
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS "")
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
set(EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS "")
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
|
||||
-mllvm
|
||||
-enable-noalias-to-md-conversion=0
|
||||
)
|
||||
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp)
|
||||
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
|
||||
else()
|
||||
message("Skipping ck_tile_gemm tests for current target")
|
||||
endif()
|
||||
endif()
|
||||
16
test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp
Normal file
16
test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineCompV3 : public TestCkTileGemmPipeline<T>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV3, KernelTypesCompV3);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
16
test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp
Normal file
16
test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineCompV4 : public TestCkTileGemmPipeline<T>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV4, KernelTypesCompV4);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
62
test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp
Normal file
62
test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp
Normal file
@@ -0,0 +1,62 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave>;
|
||||
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
|
||||
ck_tile::GemmPipelineScheduler::Interwave>;
|
||||
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
|
||||
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
|
||||
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypesMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, Interwave, Mem>
|
||||
>;
|
||||
|
||||
using KernelTypesCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>
|
||||
>;
|
||||
|
||||
using KernelTypesCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>
|
||||
>;
|
||||
|
||||
// clang-format on
|
||||
16
test/ck_tile/gemm/test_gemm_pipeline_mem.cpp
Normal file
16
test/ck_tile/gemm/test_gemm_pipeline_mem.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineMem : public TestCkTileGemmPipeline<T>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelineMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineMem, KernelTypesMem);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
103
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
Normal file
103
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
Normal file
@@ -0,0 +1,103 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef TEST_GEMM_PIPELINE_UT_CASES_INC
|
||||
#define TEST_GEMM_PIPELINE_UT_CASES_INC
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 1024;
|
||||
constexpr int K = 320;
|
||||
|
||||
for(int M : Ms)
|
||||
{
|
||||
if constexpr(std::is_same_v<typename TestFixture::ALayout,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
|
||||
}
|
||||
else
|
||||
{
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 1024;
|
||||
constexpr int K = 320;
|
||||
constexpr int VecLoadSize = (std::is_same_v<typename TestFixture::ADataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TestFixture::ADataType, ck_tile::bf8_t>)
|
||||
? 16
|
||||
: 8;
|
||||
|
||||
for(int M : Ms)
|
||||
{
|
||||
if constexpr(std::is_same_v<typename TestFixture::ALayout,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
// TODO: Can we anyhow deduce used vector load size?
|
||||
if(M % VecLoadSize == 0)
|
||||
{
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
else
|
||||
{
|
||||
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{128};
|
||||
constexpr int N = 1024;
|
||||
constexpr int K = 432;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 1024;
|
||||
constexpr int K = 512;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, LargeMatrix)
|
||||
{
|
||||
constexpr int M = 2048;
|
||||
constexpr int N = 2048;
|
||||
constexpr int K = 2048;
|
||||
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, NotSupportedArgument)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 1025;
|
||||
constexpr int K = 513;
|
||||
|
||||
constexpr bool PadM = false;
|
||||
constexpr bool PadN = false;
|
||||
constexpr bool PadK = false;
|
||||
|
||||
EXPECT_THROW((this->template Run<PadM, PadN, PadK>(M, N, K)), std::runtime_error);
|
||||
}
|
||||
|
||||
#endif
|
||||
458
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
Normal file
458
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
Normal file
@@ -0,0 +1,458 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
enum struct GemmPipelineType
|
||||
{
|
||||
Mem,
|
||||
CompV3,
|
||||
CompV4
|
||||
};
|
||||
|
||||
template <GemmPipelineType PT, typename Problem>
|
||||
struct GemmPipelineTypeSelector;
|
||||
|
||||
template <typename Problem>
|
||||
struct GemmPipelineTypeSelector<GemmPipelineType::Mem, Problem>
|
||||
{
|
||||
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrMem<Problem>;
|
||||
using pipeline = ck_tile::GemmPipelineAgBgCrMem<Problem>;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
struct GemmPipelineTypeSelector<GemmPipelineType::CompV3, Problem>
|
||||
{
|
||||
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
|
||||
{
|
||||
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<Problem>;
|
||||
using pipeline = ck_tile::GemmPipelineAgBgCrCompV4<Problem>;
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmPipeline : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
|
||||
// TODO: expose tile size through test t-param ?
|
||||
|
||||
template <bool PadM, bool PadN, bool PadK>
|
||||
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// TODO: This should be parameterized in tests
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = (PipelineType == GemmPipelineType::CompV4) ? 32 : 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
constexpr bool kPadK = PadK;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4) ? true : false;
|
||||
|
||||
// TODO: For now - but this should also be a test parameter
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
// ===============================================
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline =
|
||||
typename GemmPipelineTypeSelector<PipelineType, GemmPipelineProblem>::base_pipeline;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline =
|
||||
typename GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipeline::BlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
if constexpr(PipelineType == GemmPipelineType::CompV3)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "For compute pipeline tail number should always be Full, but have \""
|
||||
<< tail_num << "\" which is not supported! PrefetchStages: "
|
||||
<< BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(PipelineType == GemmPipelineType::Mem)
|
||||
{
|
||||
// Tail pipeline One to Seven
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Two)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Four)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Four>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Five)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Five>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Six)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Six>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Seven)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Seven>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(PipelineType == GemmPipelineType::CompV4)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always Full - #PrefetchStages
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "When there's no hot loop, this tail number \"" << tail_num
|
||||
<< "\" is not supported! " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
std::vector<int> k_batches_;
|
||||
|
||||
void SetUp() override
|
||||
{
|
||||
if constexpr(PipelineType == GemmPipelineType::CompV4)
|
||||
{
|
||||
// Only do k_batch = 1 when pipeline is CompV4
|
||||
k_batches_ = {1};
|
||||
}
|
||||
else
|
||||
{
|
||||
// Otherwise, use k_batch = 1 and 2
|
||||
k_batches_ = {1, 2};
|
||||
}
|
||||
}
|
||||
|
||||
template <bool PadM = true, bool PadN = true, bool PadK = true>
|
||||
void Run(const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
const int StrideA = 0,
|
||||
const int StrideB = 0,
|
||||
const int StrideC = 0)
|
||||
{
|
||||
for(auto kb : k_batches_)
|
||||
{
|
||||
RunSingle<PadM, PadN, PadK>(M, N, K, StrideA, StrideB, StrideC, kb);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool PadM, bool PadN, bool PadK>
|
||||
void RunSingle(const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
const int StrideA,
|
||||
const int StrideB,
|
||||
const int StrideC,
|
||||
int kbatch = 1)
|
||||
{
|
||||
using namespace ck_tile::literals;
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
// give a chance if stride is zero, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout),
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
std::size_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
std::size_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
std::size_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n);
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::GemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
|
||||
invoke_gemm<PadM, PadN, PadK>(args, ck_tile::stream_config{nullptr, false});
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_ref);
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
4
test/ck_tile/grouped_gemm/CMakeLists.txt
Normal file
4
test/ck_tile/grouped_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm test_grouped_gemm.cpp)
|
||||
endif()
|
||||
29
test/ck_tile/grouped_gemm/test_grouped_gemm.cpp
Normal file
29
test/ck_tile/grouped_gemm/test_grouped_gemm.cpp
Normal file
@@ -0,0 +1,29 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_util.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
|
||||
// std::tuple< Row, Row, Row, F16, F16, F32, F16>,
|
||||
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
|
||||
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGroupedGemm, KernelTypes);
|
||||
|
||||
#include "test_grouped_gemm_ut_cases.inc"
|
||||
25
test/ck_tile/grouped_gemm/test_grouped_gemm_ut_cases.inc
Normal file
25
test/ck_tile/grouped_gemm/test_grouped_gemm_ut_cases.inc
Normal file
@@ -0,0 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestCkTileGroupedGemm, Basic)
|
||||
{
|
||||
const int group_count = 8;
|
||||
std::vector<int> Ms;
|
||||
std::vector<int> Ns;
|
||||
std::vector<int> Ks;
|
||||
std::vector<int> stride_As;
|
||||
std::vector<int> stride_Bs;
|
||||
std::vector<int> stride_Cs;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(256 + 64 * i);
|
||||
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
stride_Cs.push_back(Ns[i]);
|
||||
}
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, group_count);
|
||||
}
|
||||
325
test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
Normal file
325
test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
Normal file
@@ -0,0 +1,325 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGroupedGemm : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
|
||||
struct GroupedGemKernelParam
|
||||
{
|
||||
static const bool kPadM = false;
|
||||
static const bool kPadN = false;
|
||||
static const bool kPadK = false;
|
||||
|
||||
static const int kBlockPerCu = 1;
|
||||
static const ck_tile::index_t M_Tile = 128;
|
||||
static const ck_tile::index_t N_Tile = 128;
|
||||
static const ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static const ck_tile::index_t M_Warp = 2;
|
||||
static const ck_tile::index_t N_Warp = 2;
|
||||
static const ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static const ck_tile::index_t M_Warp_Tile = 32;
|
||||
static const ck_tile::index_t N_Warp_Tile = 32;
|
||||
static const ck_tile::index_t K_Warp_Tile = 8;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GemmHostArgs;
|
||||
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* p_workspace_)
|
||||
{
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
|
||||
GroupedGemKernelParam::N_Tile,
|
||||
GroupedGemKernelParam::K_Tile>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::K_Warp>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GroupedGemKernelParam::K_Tile;
|
||||
const ck_tile::index_t K_split =
|
||||
(gemm_descs[0].K + k_grain - 1) / k_grain * GroupedGemKernelParam::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<blocks.x, GroupedGemKernelParam::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
|
||||
gemm_descs.size()));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "For compute pipeline tail number should always be Full, but have \""
|
||||
<< tail_num << "\" which is not supported! PrefetchStages: "
|
||||
<< BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Num K loop must be larger than number of prefetech stages."
|
||||
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
std::vector<int>& stride_As,
|
||||
std::vector<int>& stride_Bs,
|
||||
std::vector<int>& stride_Cs,
|
||||
const int group_count = 16)
|
||||
{
|
||||
using namespace ck_tile::literals;
|
||||
auto f_host_tensor_descriptor = [](std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout),
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
|
||||
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
|
||||
|
||||
a_m_k_tensors.reserve(group_count);
|
||||
b_k_n_tensors.reserve(group_count);
|
||||
c_m_n_tensors.reserve(group_count);
|
||||
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
|
||||
|
||||
a_m_k_dev_buf.reserve(group_count);
|
||||
b_k_n_dev_buf.reserve(group_count);
|
||||
c_m_n_dev_buf.reserve(group_count);
|
||||
|
||||
std::vector<grouped_gemm_kargs> gemm_descs;
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
const ck_tile::index_t M = Ms[i];
|
||||
const ck_tile::index_t N = Ns[i];
|
||||
const ck_tile::index_t K = Ks[i];
|
||||
|
||||
stride_As[i] = f_get_default_stride(M, N, stride_As[i], ALayout{});
|
||||
stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], BLayout{});
|
||||
stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{});
|
||||
|
||||
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
f_host_tensor_descriptor(M, K, stride_As[i], ALayout{})));
|
||||
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
|
||||
f_host_tensor_descriptor(K, N, stride_Bs[i], BLayout{})));
|
||||
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
|
||||
f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
|
||||
|
||||
std::cout << "gemm[" << i << "]"
|
||||
<< " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl;
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_tensors[i]);
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
b_k_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
c_m_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
c_m_n_dev_buf[i]->SetZero();
|
||||
c_m_n_tensors[i].SetZero();
|
||||
|
||||
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
// TODO add support for kbatch > 1
|
||||
static constexpr ck_tile::index_t k_batch = 1;
|
||||
gemm_descs.push_back(
|
||||
{p_a, p_b, p_c, k_batch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
|
||||
|
||||
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
|
||||
gemm_descs, ck_tile::stream_config{nullptr, false}, gemm_workspace.GetDeviceBuffer());
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
|
||||
}
|
||||
|
||||
bool pass{true};
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
|
||||
c_m_n_host_ref.SetZero();
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref);
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
4
test/ck_tile/image_to_column/CMakeLists.txt
Normal file
4
test/ck_tile/image_to_column/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_tile_image_to_column test_tile_image_to_column.cpp)
|
||||
endif()
|
||||
142
test/ck_tile/image_to_column/test_tile_image_to_column.cpp
Normal file
142
test/ck_tile/image_to_column/test_tile_image_to_column.cpp
Normal file
@@ -0,0 +1,142 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <algorithm>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/image_to_column.hpp"
|
||||
|
||||
// Host API implementation
|
||||
template <typename DataType>
|
||||
class TestCkTileImageToColumn : public ::testing::Test
|
||||
{
|
||||
static constexpr ck_tile::index_t VectorSize = 1;
|
||||
static constexpr ck_tile::index_t NDimSpatial = 2;
|
||||
|
||||
protected:
|
||||
void Run(const ck_tile::conv::ConvParam conv_params)
|
||||
{
|
||||
|
||||
using ImLayout = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
|
||||
const auto G = conv_params.G_;
|
||||
const auto N = conv_params.N_;
|
||||
const auto C = conv_params.C_;
|
||||
|
||||
const ck_tile::long_index_t NDoHoWo =
|
||||
N * std::accumulate(conv_params.output_spatial_lengths_.begin(),
|
||||
std::next(conv_params.output_spatial_lengths_.begin(), NDimSpatial),
|
||||
1,
|
||||
std::multiplies<>());
|
||||
|
||||
const ck_tile::long_index_t CZYX =
|
||||
C * std::accumulate(conv_params.filter_spatial_lengths_.begin(),
|
||||
std::next(conv_params.filter_spatial_lengths_.begin(), NDimSpatial),
|
||||
1,
|
||||
std::multiplies<>());
|
||||
|
||||
const auto in_desc =
|
||||
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(
|
||||
conv_params);
|
||||
const auto out_desc = ck_tile::HostTensorDescriptor({G, NDoHoWo, CZYX});
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<DataType> in(in_desc);
|
||||
ck_tile::HostTensor<DataType> out_device(out_desc);
|
||||
ck_tile::HostTensor<DataType> out_host(out_desc);
|
||||
|
||||
std::cout << "input: " << in.mDesc << std::endl;
|
||||
std::cout << "output: " << out_device.mDesc << std::endl;
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(in);
|
||||
|
||||
ck_tile::DeviceMem in_device_buf(in.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem out_device_buf(out_device.get_element_space_size_in_bytes());
|
||||
|
||||
in_device_buf.ToDevice(in.data());
|
||||
|
||||
using thread_tile = ck_tile::sequence<4, 4>;
|
||||
using warp_tile = ck_tile::sequence<8, 128>;
|
||||
using block_tile = ck_tile::sequence<32, 128>;
|
||||
|
||||
using Shape = ck_tile::TileImageToColumnShape<thread_tile, warp_tile, block_tile>;
|
||||
|
||||
using PipelineProblem = ck_tile::BlockImageToColumnProblem<DataType,
|
||||
DataType,
|
||||
Shape,
|
||||
NDimSpatial,
|
||||
VectorSize,
|
||||
VectorSize>;
|
||||
|
||||
using Kernel = ck_tile::ImageToColumn<PipelineProblem>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(
|
||||
in_device_buf.GetDeviceBuffer(),
|
||||
out_device_buf.GetDeviceBuffer(),
|
||||
G,
|
||||
N,
|
||||
C,
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(
|
||||
conv_params.input_spatial_lengths_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(
|
||||
conv_params.filter_spatial_lengths_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(
|
||||
conv_params.output_spatial_lengths_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial + 3>(in_desc.get_strides()),
|
||||
ck_tile::to_array<ck_tile::long_index_t, 3>(out_desc.get_strides()),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.conv_filter_strides_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(
|
||||
conv_params.conv_filter_dilations_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_left_pads_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_right_pads_));
|
||||
|
||||
const dim3 grids = Kernel::GridSize(
|
||||
kargs.N * kargs.output_spatial_lengths[0] * kargs.output_spatial_lengths[1],
|
||||
kargs.filter_spatial_lengths[0] * kargs.filter_spatial_lengths[1] * kargs.C,
|
||||
kargs.G);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 2;
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{},
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
// reference
|
||||
ck_tile::reference_im2col<DataType, DataType, NDimSpatial>(in, out_host, conv_params);
|
||||
|
||||
out_device_buf.FromDevice(out_device.data());
|
||||
bool pass = ck_tile::check_err(out_device, out_host);
|
||||
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
class TestCkTileImageToColumnFloat : public TestCkTileImageToColumn<float>
|
||||
{
|
||||
};
|
||||
|
||||
class TestCkTileImageToColumnHalf : public TestCkTileImageToColumn<ck_tile::half_t>
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(TestCkTileImageToColumnFloat, TestCorrectness)
|
||||
{
|
||||
this->Run({2, 2, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->Run({2, 2, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->Run({2, 1, 64, 1, 64, {1, 1}, {7, 7}, {3, 3}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->Run({2, 1, 64, 1, 64, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->Run({2, 2, 64, 1, 64, {3, 3}, {28, 28}, {2, 2}, {2, 2}, {1, 1}, {1, 1}});
|
||||
}
|
||||
|
||||
TEST_F(TestCkTileImageToColumnHalf, TestCorrectness)
|
||||
{
|
||||
this->Run({2, 2, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->Run({2, 2, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->Run({2, 1, 64, 1, 64, {1, 1}, {7, 7}, {3, 3}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->Run({2, 1, 64, 1, 64, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->Run({2, 2, 64, 1, 64, {3, 3}, {28, 28}, {2, 2}, {2, 2}, {1, 1}, {1, 1}});
|
||||
}
|
||||
Reference in New Issue
Block a user