mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Merge branch 'develop' into ck_tile/gemm_blockscale_abquant
This commit is contained in:
@@ -261,6 +261,7 @@ add_subdirectory(gemm_multiply_multiply_wp)
|
||||
add_subdirectory(gemm_split_k)
|
||||
add_subdirectory(gemm_universal)
|
||||
add_subdirectory(gemm_universal_preshuffle)
|
||||
add_subdirectory(gemm_ab_scale)
|
||||
add_subdirectory(gemm_b_scale)
|
||||
add_subdirectory(gemm_universal_streamk)
|
||||
add_subdirectory(gemm_reduce)
|
||||
@@ -310,3 +311,4 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
endif()
|
||||
add_subdirectory(position_embedding)
|
||||
add_subdirectory(scatter_gather)
|
||||
add_subdirectory(util)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
|
||||
0
test/ck_tile/gemm_block_scale/CMakeLists.txt
Normal file → Executable file
0
test/ck_tile/gemm_block_scale/CMakeLists.txt
Normal file → Executable file
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
@@ -138,8 +138,10 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType_) < sizeof(BDataType_), ADataType_, BDataType_>;
|
||||
using ComputeType = std::conditional_t<
|
||||
std::is_same_v<BDataType_, ck_tile::pk_fp4_raw_t>,
|
||||
ADataType_,
|
||||
std::conditional_t<sizeof(ADataType_) < sizeof(BDataType_), ADataType_, BDataType_>>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType_, AccDataType_>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
@@ -16,9 +16,12 @@ using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using UInt8 = ck_tile::pk_fp4_raw_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
using GroupSize32 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 32>>;
|
||||
|
||||
// 2d block sizes for BQuant
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
@@ -42,6 +45,9 @@ using BQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, UInt8, UInt8, BF16, BQuantGrouped, GemmConfigMxFp4, GroupSize64>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, UInt8, UInt8, BF16, BQuantGrouped, GemmConfigMxFp4, GroupSize32>,
|
||||
|
||||
// 2d cases with grouping also on the n axis
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
@@ -60,6 +60,13 @@ struct GemmConfigPrefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
};
|
||||
|
||||
// Tile-size configuration derived from GemmConfigBase that declares a
// 128 x 128 x 128 (M x N x K) tile.
// NOTE(review): by its name this is intended for the MX-FP4 (pk_fp4_raw_t B operand)
// test cases — confirm against the typed-test list that references it.
struct GemmConfigMxFp4 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
@@ -403,7 +410,8 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_B =
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t> ? (K / 2) : K;
|
||||
const ck_tile::index_t stride_C = N;
|
||||
|
||||
// BQuant uses block/grouped quantization for B matrix
|
||||
@@ -414,15 +422,27 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
// Generate test data
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t> ? K / 2 : K,
|
||||
N,
|
||||
stride_B,
|
||||
this->is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<QDataType> bq_bqk_bqn(
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{})));
|
||||
|
||||
// Initialize data with random values
|
||||
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{0.f, 1.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{-1.0f, 1.0f}(bq_bqk_bqn);
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{125.f, 130.f}(bq_bqk_bqn);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{0.f, 1.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{-1.0f, 1.0f}(bq_bqk_bqn);
|
||||
}
|
||||
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
|
||||
@@ -501,13 +521,22 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Run reference BQuant implementation
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
||||
ck_tile::reference_mxfp4gemm_quant<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
|
||||
else
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
|
||||
|
||||
// Get device result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
@@ -580,33 +609,37 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline =
|
||||
std::conditional_t<PreshuffleB == false,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>>;
|
||||
using GemmPipeline = std::conditional_t<
|
||||
PreshuffleB == false,
|
||||
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
false, // transpose_c
|
||||
ck_tile::memory_operation_enum::set,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledMMAPermuteN>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
ADataType,
|
||||
BDataType>,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
false, // transpose_c
|
||||
ck_tile::memory_operation_enum::set,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledMMAPermuteN>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host/fill.hpp"
|
||||
#include "ck_tile/host/joinable_thread.hpp"
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# Register test_ck_tile_wg_16x16x128_fp4 only when GPU_TARGETS contains "gfx95".
if(GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_wg_16x16x128_fp4 test_f32_16x16x128_fp4.cpp)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
9
test/gemm_ab_scale/CMakeLists.txt
Normal file
9
test/gemm_ab_scale/CMakeLists.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# Register the gemm_ab_scale test only when GPU_TARGETS contains "gfx94", "gfx95" or "gfx12".
if(GPU_TARGETS MATCHES "gfx9[45]|gfx12")
|
||||
add_gtest_executable(test_gemm_ab_scale test_gemm_ab_scale.cpp)
|
||||
# NOTE(review): `result` is presumably set by add_gtest_executable (0 when the
# target was actually created) — confirm against the macro definition.
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_ab_scale PRIVATE utility device_gemm_ab_scale_instance)
|
||||
endif()
|
||||
endif()
|
||||
236
test/gemm_ab_scale/test_gemm_ab_scale.cpp
Normal file
236
test/gemm_ab_scale/test_gemm_ab_scale.cpp
Normal file
@@ -0,0 +1,236 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_ab_scale_util.hpp"
|
||||
|
||||
// Short names for the element types used in the KernelTypes list below.
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using F8 = ck::f8_t;
|
||||
|
||||
// Short names for the GEMM tensor layouts prepended by the fixtures below.
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {

// Compile-time concatenation of two std::tuple types.
// Only the std::tuple/std::tuple specialization is defined; any other pair of
// arguments leaves tuple_concat an incomplete type.
template <typename Lhs, typename Rhs>
struct tuple_concat;

// tuple_concat<std::tuple<A...>, std::tuple<B...>>::type is std::tuple<A..., B...>.
template <typename... Head, typename... Tail>
struct tuple_concat<std::tuple<Head...>, std::tuple<Tail...>>
{
    using type = std::tuple<Head..., Tail...>;
};

} // namespace
|
||||
|
||||
// Fixture that prepends <Row, Col, Row> to the data-type tuple supplied by the
// typed test before handing it to ck::test::TestGemmABScale (which reads the
// first three tuple elements as the A, B and E layouts).
template <typename Tuple>
|
||||
class TestGemmABScale_MK_NK : public ck::test::TestGemmABScale<
|
||||
typename tuple_concat<std::tuple<Row, Col, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// Fixture that prepends <Row, Row, Row> to the data-type tuple supplied by the
// typed test before handing it to ck::test::TestGemmABScale (which reads the
// first three tuple elements as the A, B and E layouts).
template <typename Tuple>
|
||||
class TestGemmABScale_MK_KN : public ck::test::TestGemmABScale<
|
||||
typename tuple_concat<std::tuple<Row, Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// Fixture that prepends <Col, Row, Row> to the data-type tuple supplied by the
// typed test before handing it to ck::test::TestGemmABScale (which reads the
// first three tuple elements as the A, B and E layouts).
template <typename Tuple>
|
||||
class TestGemmABScale_KM_KN : public ck::test::TestGemmABScale<
|
||||
typename tuple_concat<std::tuple<Col, Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// A0DataType, A1DataType, B0DataType, B1DataType, ComputeDataType, EDataType
|
||||
std::tuple< F8, F32, F8, F32, F8, BF16>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Instantiate each layout fixture once per entry of KernelTypes.
TYPED_TEST_SUITE(TestGemmABScale_MK_NK, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGemmABScale_MK_KN, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGemmABScale_KM_KN, KernelTypes);
|
||||
|
||||
// Row Col
|
||||
TYPED_TEST(TestGemmABScale_MK_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 1024;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
// Row-major A, column-major B, with K = 704 instead of 1024.
// Sweeps M = 1..6 with N = 512; A and B use K as leading dimension, E uses N.
TYPED_TEST(TestGemmABScale_MK_NK, SmallMPadK)
{
    constexpr int N = 512;
    constexpr int K = 704;

    const std::vector<int> m_values{1, 2, 3, 4, 5, 6};
    for(const int M : m_values)
    {
        this->Run(M, N, K, /*StrideA=*/K, /*StrideB=*/K, /*StrideE=*/N);
    }
}
|
||||
|
||||
// Row-major A, column-major B.
// Sweeps M over {127, 255, 312, 799, 1573} with N = 512 and K = 1024.
TYPED_TEST(TestGemmABScale_MK_NK, MidLargeM)
{
    constexpr int N = 512;
    constexpr int K = 1024;

    const std::vector<int> m_values{127, 255, 312, 799, 1573};
    for(const int M : m_values)
    {
        this->Run(M, N, K, /*StrideA=*/K, /*StrideB=*/K, /*StrideE=*/N);
    }
}
|
||||
|
||||
// Row-major A, column-major B.
// Single problem: M = 512, N = 512, K = 1024.
TYPED_TEST(TestGemmABScale_MK_NK, Regular)
{
    constexpr int N = 512;
    constexpr int K = 1024;

    const std::vector<int> m_values{512};
    for(const int M : m_values)
    {
        this->Run(M, N, K, /*StrideA=*/K, /*StrideB=*/K, /*StrideE=*/N);
    }
}
|
||||
|
||||
// Row Row
|
||||
TYPED_TEST(TestGemmABScale_MK_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 1024;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
// Row-major A, row-major B, with K = 704 instead of 1024.
// Sweeps M = 1..6 with N = 512; A uses K as leading dimension, B and E use N.
TYPED_TEST(TestGemmABScale_MK_KN, SmallMPadK)
{
    constexpr int N = 512;
    constexpr int K = 704;

    const std::vector<int> m_values{1, 2, 3, 4, 5, 6};
    for(const int M : m_values)
    {
        this->Run(M, N, K, /*StrideA=*/K, /*StrideB=*/N, /*StrideE=*/N);
    }
}
|
||||
|
||||
// Row-major A, row-major B.
// Sweeps M over {127, 255, 312, 799, 1573} with N = 512 and K = 1024.
TYPED_TEST(TestGemmABScale_MK_KN, MidLargeM)
{
    constexpr int N = 512;
    constexpr int K = 1024;

    const std::vector<int> m_values{127, 255, 312, 799, 1573};
    for(const int M : m_values)
    {
        this->Run(M, N, K, /*StrideA=*/K, /*StrideB=*/N, /*StrideE=*/N);
    }
}
|
||||
|
||||
// Row-major A, row-major B.
// Single problem: M = 512, N = 512, K = 1024.
TYPED_TEST(TestGemmABScale_MK_KN, Regular)
{
    constexpr int N = 512;
    constexpr int K = 1024;

    const std::vector<int> m_values{512};
    for(const int M : m_values)
    {
        this->Run(M, N, K, /*StrideA=*/K, /*StrideB=*/N, /*StrideE=*/N);
    }
}
|
||||
|
||||
// Col Row
|
||||
TYPED_TEST(TestGemmABScale_KM_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{16, 32};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 1024;
|
||||
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
{
|
||||
int StrideA = M;
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
}
|
||||
|
||||
// Column-major A, row-major B, with K = 704 instead of 1024.
// Sweeps M over {16, 32} with N = 512. A's leading dimension follows M;
// B and E use N.
TYPED_TEST(TestGemmABScale_KM_KN, SmallMPadK)
{
    constexpr int N = 512;
    constexpr int K = 704;

    const std::vector<int> m_values{16, 32};
    for(const int M : m_values)
    {
        this->Run(M, N, K, /*StrideA=*/M, /*StrideB=*/N, /*StrideE=*/N);
    }
}
|
||||
|
||||
// Column-major A, row-major B.
// Sweeps M over {128, 256} with N = 512 and K = 1024. A's leading dimension
// follows M; B and E use N.
TYPED_TEST(TestGemmABScale_KM_KN, MidLargeM)
{
    constexpr int N = 512;
    constexpr int K = 1024;

    const std::vector<int> m_values{128, 256};
    for(const int M : m_values)
    {
        this->Run(M, N, K, /*StrideA=*/M, /*StrideB=*/N, /*StrideE=*/N);
    }
}
|
||||
|
||||
// Column-major A, row-major B.
// Single problem: M = 512, N = 512, K = 1024. A's leading dimension is M;
// B and E use N.
TYPED_TEST(TestGemmABScale_KM_KN, Regular)
{
    constexpr int N = 512;
    constexpr int K = 1024;

    const std::vector<int> m_values{512};
    for(const int M : m_values)
    {
        this->Run(M, N, K, /*StrideA=*/M, /*StrideB=*/N, /*StrideE=*/N);
    }
}
|
||||
102
test/gemm_ab_scale/test_gemm_ab_scale_util.hpp
Normal file
102
test/gemm_ab_scale/test_gemm_ab_scale_util.hpp
Normal file
@@ -0,0 +1,102 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "include/ck/utility/data_type.hpp"
|
||||
#include "profiler/profile_gemm_ab_scale_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmABScale : public testing::Test
|
||||
{
|
||||
using F32 = float;
|
||||
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using ELayout = std::tuple_element_t<2, Tuple>;
|
||||
using A0DataType = std::tuple_element_t<3, Tuple>;
|
||||
using A1DataType = std::tuple_element_t<4, Tuple>;
|
||||
using B0DataType = std::tuple_element_t<5, Tuple>;
|
||||
using B1DataType = std::tuple_element_t<6, Tuple>;
|
||||
using ComputeDataType = std::tuple_element_t<7, Tuple>;
|
||||
using EDataType = std::tuple_element_t<8, Tuple>;
|
||||
|
||||
public:
|
||||
static constexpr ck::index_t ScaleBlockM = 1;
|
||||
static constexpr ck::index_t ScaleBlockN = 128;
|
||||
static constexpr ck::index_t ScaleBlockK = 128;
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1;
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false;
|
||||
std::vector<int> k_batches_;
|
||||
|
||||
void SetUp() override { k_batches_ = {1, 2}; }
|
||||
|
||||
void Run(const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
const int StrideA,
|
||||
const int StrideB,
|
||||
const int StrideE)
|
||||
{
|
||||
for(auto kb : k_batches_)
|
||||
{
|
||||
RunSingle(M, N, K, StrideA, StrideB, StrideE, kb);
|
||||
}
|
||||
}
|
||||
|
||||
void RunSingle(const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
const int StrideA,
|
||||
const int StrideB,
|
||||
const int StrideE,
|
||||
int kbatch = 1,
|
||||
int n_warmup = 1,
|
||||
int n_iter = 10)
|
||||
{
|
||||
bool pass = ck::profiler::profile_gemm_ab_scale_impl<A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
ComputeDataType,
|
||||
F32,
|
||||
EDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>(verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideE,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
@@ -2,8 +2,8 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9[45]|gfx12")
|
||||
add_gtest_executable(test_gemm_blockscale_wp_xdl_fp8 test_gemm_blockscale_wp_xdl_fp8.cpp)
|
||||
add_gtest_executable(test_gemm_blockscale_wp_fp8 test_gemm_blockscale_wp_fp8.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_blockscale_wp_xdl_fp8 PRIVATE utility device_gemm_blockscale_wp_instance)
|
||||
target_link_libraries(test_gemm_blockscale_wp_fp8 PRIVATE utility device_gemm_blockscale_wp_instance)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
7
test/util/CMakeLists.txt
Normal file
7
test/util/CMakeLists.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# Unit tests for ck/utility/sequence.hpp.
add_gtest_executable(unit_sequence unit_sequence.cpp)
|
||||
# NOTE(review): `result` is presumably set by add_gtest_executable (0 when the
# target was actually created) — confirm against the macro definition.
if(result EQUAL 0)
|
||||
target_link_libraries(unit_sequence PRIVATE utility)
|
||||
endif()
|
||||
684
test/util/unit_sequence.cpp
Normal file
684
test/util/unit_sequence.cpp
Normal file
@@ -0,0 +1,684 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/utility/functional.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
// Test basic Sequence construction and properties
|
||||
TEST(Sequence, BasicConstruction)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4, 5>;
|
||||
EXPECT_EQ(Seq::Size(), 5);
|
||||
EXPECT_EQ(Seq::mSize, 5);
|
||||
}
|
||||
|
||||
// An empty Sequence reports length zero through both Size() and mSize.
TEST(Sequence, EmptySequence)
{
    using NoElements = Sequence<>;

    EXPECT_EQ(NoElements::Size(), 0);
    EXPECT_EQ(NoElements::mSize, 0);
}
|
||||
|
||||
// Test At() method
|
||||
TEST(Sequence, AtRuntime)
|
||||
{
|
||||
using Seq = Sequence<10, 20, 30, 40>;
|
||||
EXPECT_EQ(Seq::At(0), 10);
|
||||
EXPECT_EQ(Seq::At(1), 20);
|
||||
EXPECT_EQ(Seq::At(2), 30);
|
||||
EXPECT_EQ(Seq::At(3), 40);
|
||||
}
|
||||
|
||||
// At(Number<I>{}) with a compile-time index yields the element at position I.
TEST(Sequence, AtCompileTime)
{
    using Tens = Sequence<10, 20, 30, 40>;

    EXPECT_EQ(Tens::At(Number<0>{}), 10);
    EXPECT_EQ(Tens::At(Number<1>{}), 20);
    EXPECT_EQ(Tens::At(Number<2>{}), 30);
    EXPECT_EQ(Tens::At(Number<3>{}), 40);
}
|
||||
|
||||
// operator[] on a Sequence value, indexed with Number<I>{}, yields element I.
TEST(Sequence, OperatorBracket)
{
    constexpr auto fives = Sequence<5, 10, 15>{};

    EXPECT_EQ(fives[Number<0>{}], 5);
    EXPECT_EQ(fives[Number<1>{}], 10);
    EXPECT_EQ(fives[Number<2>{}], 15);
}
|
||||
|
||||
// Test Front() and Back()
|
||||
TEST(Sequence, FrontBack)
|
||||
{
|
||||
using Seq = Sequence<100, 200, 300>;
|
||||
EXPECT_EQ(Seq::Front(), 100);
|
||||
EXPECT_EQ(Seq::Back(), 300);
|
||||
}
|
||||
|
||||
// With a single element, Front() and Back() both return that element.
TEST(Sequence, FrontBackSingleElement)
{
    using Single = Sequence<42>;

    EXPECT_EQ(Single::Front(), 42);
    EXPECT_EQ(Single::Back(), 42);
}
|
||||
|
||||
// Test PushFront and PushBack
|
||||
TEST(Sequence, PushFront)
|
||||
{
|
||||
using Seq = Sequence<2, 3, 4>;
|
||||
using Result = decltype(Seq::PushFront(Sequence<1>{}));
|
||||
using Expected = Sequence<1, 2, 3, 4>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// PushFront(Number...) prepends the given numbers in argument order.
TEST(Sequence, PushFrontNumbers)
{
    using Tail   = Sequence<3, 4>;
    using Actual = decltype(Tail::PushFront(Number<1>{}, Number<2>{}));

    constexpr bool matches = is_same<Actual, Sequence<1, 2, 3, 4>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
// PushBack(Sequence) appends the given sequence's elements.
TEST(Sequence, PushBack)
{
    using Head   = Sequence<1, 2, 3>;
    using Actual = decltype(Head::PushBack(Sequence<4, 5>{}));

    constexpr bool matches = is_same<Actual, Sequence<1, 2, 3, 4, 5>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
// PushBack(Number...) appends the given numbers in argument order.
TEST(Sequence, PushBackNumbers)
{
    using Head   = Sequence<1, 2>;
    using Actual = decltype(Head::PushBack(Number<3>{}, Number<4>{}));

    constexpr bool matches = is_same<Actual, Sequence<1, 2, 3, 4>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
// Test PopFront and PopBack
|
||||
TEST(Sequence, PopFront)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4>;
|
||||
using Result = decltype(Seq::PopFront());
|
||||
using Expected = Sequence<2, 3, 4>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// PopBack() drops the last element.
TEST(Sequence, PopBack)
{
    using Source = Sequence<1, 2, 3, 4>;
    using Actual = decltype(Source::PopBack());

    constexpr bool matches = is_same<Actual, Sequence<1, 2, 3>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
// Test Extract
|
||||
TEST(Sequence, ExtractByNumbers)
|
||||
{
|
||||
using Seq = Sequence<10, 20, 30, 40, 50>;
|
||||
using Result = decltype(Seq::Extract(Number<0>{}, Number<2>{}, Number<4>{}));
|
||||
using Expected = Sequence<10, 30, 50>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Extract(Sequence) picks the elements at the positions held by the argument.
TEST(Sequence, ExtractBySequence)
{
    using Source = Sequence<10, 20, 30, 40, 50>;
    using Actual = decltype(Source::Extract(Sequence<1, 3>{}));

    constexpr bool matches = is_same<Actual, Sequence<20, 40>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
// Test Modify
|
||||
TEST(Sequence, Modify)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4>;
|
||||
using Result = decltype(Seq::Modify(Number<2>{}, Number<99>{}));
|
||||
using Expected = Sequence<1, 2, 99, 4>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test Transform
|
||||
TEST(Sequence, Transform)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4>;
|
||||
auto double_it = [](auto x) { return 2 * x; };
|
||||
using Result = decltype(Seq::Transform(double_it));
|
||||
using Expected = Sequence<2, 4, 6, 8>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test Reverse
|
||||
TEST(Sequence, Reverse)
|
||||
{
|
||||
using Seq = Sequence<1, 2, 3, 4, 5>;
|
||||
using Result = decltype(Seq::Reverse());
|
||||
using Expected = Sequence<5, 4, 3, 2, 1>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Reversing a one-element Sequence gives the same type back.
TEST(Sequence, ReverseSingleElement)
{
    using Single = Sequence<42>;
    using Actual = decltype(Single::Reverse());

    constexpr bool matches = is_same<Actual, Single>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
// Test ReorderGivenNew2Old
|
||||
TEST(Sequence, ReorderGivenNew2Old)
|
||||
{
|
||||
using Seq = Sequence<10, 20, 30, 40>;
|
||||
using Result = decltype(Seq::ReorderGivenNew2Old(Sequence<3, 1, 2, 0>{}));
|
||||
using Expected = Sequence<40, 20, 30, 10>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test ReorderGivenOld2New
|
||||
TEST(Sequence, ReorderGivenOld2New)
|
||||
{
|
||||
using Seq = Sequence<10, 20, 30, 40>;
|
||||
using Result = decltype(Seq::ReorderGivenOld2New(Sequence<3, 1, 2, 0>{}));
|
||||
using Expected = Sequence<40, 20, 30, 10>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test arithmetic_sequence_gen
|
||||
TEST(SequenceGen, ArithmeticSequence)
|
||||
{
|
||||
using Result = typename arithmetic_sequence_gen<0, 5, 1>::type;
|
||||
using Expected = Sequence<0, 1, 2, 3, 4>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// arithmetic_sequence_gen<0, 10, 2> yields the even numbers 0..8.
TEST(SequenceGen, ArithmeticSequenceWithIncrement)
{
    using Actual = typename arithmetic_sequence_gen<0, 10, 2>::type;

    constexpr bool matches = is_same<Actual, Sequence<0, 2, 4, 6, 8>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
// arithmetic_sequence_gen<10, 5, -1> counts down from 10 to 6.
TEST(SequenceGen, ArithmeticSequenceNegativeIncrement)
{
    using Actual = typename arithmetic_sequence_gen<10, 5, -1>::type;

    constexpr bool matches = is_same<Actual, Sequence<10, 9, 8, 7, 6>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
// Equal begin and end values produce an empty Sequence.
TEST(SequenceGen, ArithmeticSequenceEmpty)
{
    using Actual = typename arithmetic_sequence_gen<5, 5, 1>::type;

    constexpr bool matches = is_same<Actual, Sequence<>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
// Test uniform_sequence_gen
|
||||
TEST(SequenceGen, UniformSequence)
|
||||
{
|
||||
using Result = typename uniform_sequence_gen<5, 42>::type;
|
||||
using Expected = Sequence<42, 42, 42, 42, 42>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// A requested size of zero yields an empty Sequence regardless of the value.
TEST(SequenceGen, UniformSequenceZeroSize)
{
    using Actual = typename uniform_sequence_gen<0, 42>::type;

    constexpr bool matches = is_same<Actual, Sequence<>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
// Test make_index_sequence
|
||||
TEST(SequenceGen, MakeIndexSequence)
|
||||
{
|
||||
using Result = make_index_sequence<5>;
|
||||
using Expected = Sequence<0, 1, 2, 3, 4>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// make_index_sequence<0> is the empty Sequence.
TEST(SequenceGen, MakeIndexSequenceZero)
{
    using Actual = make_index_sequence<0>;

    constexpr bool matches = is_same<Actual, Sequence<>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
// Test sequence_merge
|
||||
TEST(SequenceMerge, MergeTwoSequences)
|
||||
{
|
||||
using Seq1 = Sequence<1, 2, 3>;
|
||||
using Seq2 = Sequence<4, 5, 6>;
|
||||
using Result = typename sequence_merge<Seq1, Seq2>::type;
|
||||
using Expected = Sequence<1, 2, 3, 4, 5, 6>;
|
||||
EXPECT_TRUE((is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// sequence_merge accepts more than two sequences and keeps argument order.
TEST(SequenceMerge, MergeMultipleSequences)
{
    using First  = Sequence<1, 2>;
    using Second = Sequence<3, 4>;
    using Third  = Sequence<5, 6>;
    using Actual = typename sequence_merge<First, Second, Third>::type;

    constexpr bool matches = is_same<Actual, Sequence<1, 2, 3, 4, 5, 6>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
// Merging a single sequence gives that sequence back unchanged.
TEST(SequenceMerge, MergeSingleSequence)
{
    using Only   = Sequence<1, 2, 3>;
    using Actual = typename sequence_merge<Only>::type;

    constexpr bool matches = is_same<Actual, Only>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
// Test sequence_split
|
||||
TEST(SequenceSplit, SplitInMiddle)
{
    // Splitting six elements at index 3 is expected to give two three-element halves.
    using Halves = sequence_split<Sequence<1, 2, 3, 4, 5, 6>, 3>;

    constexpr bool left_ok  = is_same<typename Halves::left_type, Sequence<1, 2, 3>>::value;
    constexpr bool right_ok = is_same<typename Halves::right_type, Sequence<4, 5, 6>>::value;

    EXPECT_TRUE(left_ok);
    EXPECT_TRUE(right_ok);
}
|
||||
|
||||
TEST(SequenceSplit, SplitAtBeginning)
{
    // A split point of 0 is expected to leave the left part empty and the right part whole.
    using Input  = Sequence<1, 2, 3, 4>;
    using Halves = sequence_split<Input, 0>;

    constexpr bool left_ok  = is_same<typename Halves::left_type, Sequence<>>::value;
    constexpr bool right_ok = is_same<typename Halves::right_type, Sequence<1, 2, 3, 4>>::value;

    EXPECT_TRUE(left_ok);
    EXPECT_TRUE(right_ok);
}
|
||||
|
||||
TEST(SequenceSplit, SplitAtEnd)
{
    // A split point equal to the size is expected to leave the right part empty.
    using Input  = Sequence<1, 2, 3, 4>;
    using Halves = sequence_split<Input, 4>;

    constexpr bool left_ok  = is_same<typename Halves::left_type, Sequence<1, 2, 3, 4>>::value;
    constexpr bool right_ok = is_same<typename Halves::right_type, Sequence<>>::value;

    EXPECT_TRUE(left_ok);
    EXPECT_TRUE(right_ok);
}
|
||||
|
||||
// Test sequence_sort
|
||||
TEST(SequenceSort, SortAscending)
{
    // Sorting with math::less is expected to order the values from smallest to largest.
    using Unsorted = Sequence<5, 2, 8, 1, 9>;
    using Want     = Sequence<1, 2, 5, 8, 9>;
    using Got      = typename sequence_sort<Unsorted, math::less<index_t>>::type;

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceSort, SortDescending)
{
    // Local comparator that orders larger values first.
    struct descending_order
    {
        __host__ __device__ constexpr bool operator()(index_t lhs, index_t rhs) const
        {
            return lhs > rhs;
        }
    };

    // Sorting with the comparator above is expected to order values from largest to smallest.
    using Unsorted = Sequence<5, 2, 8, 1, 9>;
    using Want     = Sequence<9, 8, 5, 2, 1>;
    using Got      = typename sequence_sort<Unsorted, descending_order>::type;

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceSort, SortAlreadySorted)
{
    // Input that is already ascending is expected to come back unchanged.
    using Ordered = Sequence<1, 2, 3, 4, 5>;
    using Got     = typename sequence_sort<Ordered, math::less<index_t>>::type;

    constexpr bool unchanged = is_same<Got, Sequence<1, 2, 3, 4, 5>>::value;
    EXPECT_TRUE(unchanged);
}
|
||||
|
||||
TEST(SequenceSort, SortWithDuplicates)
{
    // A plain sort is expected to keep repeated values (two 1s and two 5s here).
    using Unsorted = Sequence<3, 1, 4, 1, 5, 9, 2, 6, 5>;
    using Want     = Sequence<1, 1, 2, 3, 4, 5, 5, 6, 9>;
    using Got      = typename sequence_sort<Unsorted, math::less<index_t>>::type;

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceSort, SortEmptySequence)
{
    // Sorting the empty sequence is expected to give the empty sequence.
    using Got = typename sequence_sort<Sequence<>, math::less<index_t>>::type;

    constexpr bool is_empty = is_same<Got, Sequence<>>::value;
    EXPECT_TRUE(is_empty);
}
|
||||
|
||||
TEST(SequenceSort, SortSingleElement)
{
    // A one-element sequence is expected to be returned as-is.
    using Got = typename sequence_sort<Sequence<42>, math::less<index_t>>::type;

    constexpr bool unchanged = is_same<Got, Sequence<42>>::value;
    EXPECT_TRUE(unchanged);
}
|
||||
|
||||
// Test sequence_unique_sort
|
||||
TEST(SequenceUniqueSort, UniqueSort)
{
    // The result is expected to be ascending with the repeated 1 and 5 collapsed to one each.
    using Unsorted = Sequence<3, 1, 4, 1, 5, 9, 2, 6, 5>;
    using Want     = Sequence<1, 2, 3, 4, 5, 6, 9>;
    using Got =
        typename sequence_unique_sort<Unsorted, math::less<index_t>, math::equal<index_t>>::type;

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceUniqueSort, UniqueSortNoDuplicates)
{
    // With no repeated values the unique sort is expected to match a plain ascending sort.
    using Unsorted = Sequence<5, 2, 8, 1, 9>;
    using Want     = Sequence<1, 2, 5, 8, 9>;
    using Got =
        typename sequence_unique_sort<Unsorted, math::less<index_t>, math::equal<index_t>>::type;

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceUniqueSort, UniqueSortAllSame)
{
    // Four identical values are expected to collapse to a single element.
    using Repeated = Sequence<5, 5, 5, 5>;
    using Got =
        typename sequence_unique_sort<Repeated, math::less<index_t>, math::equal<index_t>>::type;

    constexpr bool collapsed = is_same<Got, Sequence<5>>::value;
    EXPECT_TRUE(collapsed);
}
|
||||
|
||||
// Test is_valid_sequence_map
|
||||
TEST(SequenceMap, ValidMap)
{
    // The identity ordering <0, 1, 2, 3> is expected to be accepted as a valid map.
    constexpr bool accepted = is_valid_sequence_map<Sequence<0, 1, 2, 3>>::value;
    EXPECT_TRUE(accepted);
}
|
||||
|
||||
TEST(SequenceMap, ValidMapPermuted)
{
    // A reordering of 0..3 is expected to be accepted as a valid map.
    constexpr bool accepted = is_valid_sequence_map<Sequence<2, 0, 3, 1>>::value;
    EXPECT_TRUE(accepted);
}
|
||||
|
||||
TEST(SequenceMap, InvalidMapDuplicate)
{
    // The index 1 appears twice (and 2 never), so the map is expected to be rejected.
    constexpr bool accepted = is_valid_sequence_map<Sequence<0, 1, 1, 3>>::value;
    EXPECT_FALSE(accepted);
}
|
||||
|
||||
TEST(SequenceMap, InvalidMapMissing)
{
    // Index 2 is absent and 4 is present instead, so the map is expected to be rejected.
    constexpr bool accepted = is_valid_sequence_map<Sequence<0, 1, 3, 4>>::value;
    EXPECT_FALSE(accepted);
}
|
||||
|
||||
// Test sequence_map_inverse
|
||||
// Note: sequence_map_inverse inverts a mapping where Map[i] = j means old position i maps to new
|
||||
// position j. The inverse tells us that new position i came from old position inverse[i].
|
||||
TEST(SequenceMapInverse, InverseMap)
{
    // Map = <2, 0, 3, 1> means: old[0]->new[2], old[1]->new[0], old[2]->new[3], old[3]->new[1]
    // Inverse should be: new[0]<-old[1], new[1]<-old[3], new[2]<-old[0], new[3]<-old[2]
    using Map    = Sequence<2, 0, 3, 1>;
    using Result = typename sequence_map_inverse<Map>::type;

    // Round trip: Map[Result[i]] must equal i for every i.
    constexpr index_t round_trip_0 = Map::At(Number<Result::At(Number<0>{})>{});
    constexpr index_t round_trip_1 = Map::At(Number<Result::At(Number<1>{})>{});
    constexpr index_t round_trip_2 = Map::At(Number<Result::At(Number<2>{})>{});
    constexpr index_t round_trip_3 = Map::At(Number<Result::At(Number<3>{})>{});

    // Compare the indices themselves (not a bool against true) so that a failing
    // expectation reports the value that was actually produced.
    EXPECT_EQ(round_trip_0, 0);
    EXPECT_EQ(round_trip_1, 1);
    EXPECT_EQ(round_trip_2, 2);
    EXPECT_EQ(round_trip_3, 3);
}
|
||||
|
||||
TEST(SequenceMapInverse, InverseIdentityMap)
{
    // The map under test sends every position to itself.
    using Map    = Sequence<0, 1, 2, 3>;
    using Result = typename sequence_map_inverse<Map>::type;

    // Round trip: Map[Result[i]] must equal i for every i.
    constexpr index_t round_trip_0 = Map::At(Number<Result::At(Number<0>{})>{});
    constexpr index_t round_trip_1 = Map::At(Number<Result::At(Number<1>{})>{});
    constexpr index_t round_trip_2 = Map::At(Number<Result::At(Number<2>{})>{});
    constexpr index_t round_trip_3 = Map::At(Number<Result::At(Number<3>{})>{});

    // Compare the indices themselves (not a bool against true) so that a failing
    // expectation reports the value that was actually produced.
    EXPECT_EQ(round_trip_0, 0);
    EXPECT_EQ(round_trip_1, 1);
    EXPECT_EQ(round_trip_2, 2);
    EXPECT_EQ(round_trip_3, 3);
}
|
||||
|
||||
// Test sequence operators
|
||||
TEST(SequenceOperators, Equality)
{
    // Two sequences with identical elements, plus one that differs in its last element.
    constexpr auto reference = Sequence<1, 2, 3>{};
    constexpr auto identical = Sequence<1, 2, 3>{};
    constexpr auto different = Sequence<1, 2, 4>{};

    EXPECT_TRUE(reference == identical);
    EXPECT_FALSE(reference == different);
}
|
||||
|
||||
TEST(SequenceOperators, Addition)
{
    // operator+ on two sequences is expected to add them element by element.
    using Lhs  = Sequence<1, 2, 3>;
    using Rhs  = Sequence<4, 5, 6>;
    using Want = Sequence<5, 7, 9>;
    using Got  = decltype(Lhs{} + Rhs{});

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceOperators, Subtraction)
{
    // operator- on two sequences is expected to subtract them element by element.
    using Lhs  = Sequence<10, 20, 30>;
    using Rhs  = Sequence<1, 2, 3>;
    using Want = Sequence<9, 18, 27>;
    using Got  = decltype(Lhs{} - Rhs{});

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceOperators, Multiplication)
{
    // operator* on two sequences is expected to multiply them element by element.
    using Lhs  = Sequence<2, 3, 4>;
    using Rhs  = Sequence<5, 6, 7>;
    using Want = Sequence<10, 18, 28>;
    using Got  = decltype(Lhs{} * Rhs{});

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceOperators, Division)
{
    // operator/ on two sequences is expected to divide them element by element.
    using Lhs  = Sequence<10, 20, 30>;
    using Rhs  = Sequence<2, 4, 5>;
    using Want = Sequence<5, 5, 6>;
    using Got  = decltype(Lhs{} / Rhs{});

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceOperators, Modulo)
{
    // operator% on two sequences is expected to take the remainder element by element.
    using Lhs  = Sequence<10, 20, 30>;
    using Rhs  = Sequence<3, 7, 8>;
    using Want = Sequence<1, 6, 6>;
    using Got  = decltype(Lhs{} % Rhs{});

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceOperators, AdditionWithNumber)
{
    // Sequence + Number is expected to add the number to every element.
    using Input = Sequence<1, 2, 3>;
    using Got   = decltype(Input{} + Number<10>{});

    constexpr bool matches = is_same<Got, Sequence<11, 12, 13>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
TEST(SequenceOperators, SubtractionWithNumber)
{
    // Sequence - Number is expected to subtract the number from every element.
    using Input = Sequence<10, 20, 30>;
    using Got   = decltype(Input{} - Number<5>{});

    constexpr bool matches = is_same<Got, Sequence<5, 15, 25>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
TEST(SequenceOperators, MultiplicationWithNumber)
{
    // Sequence * Number is expected to scale every element by the number.
    using Input = Sequence<2, 3, 4>;
    using Got   = decltype(Input{} * Number<3>{});

    constexpr bool matches = is_same<Got, Sequence<6, 9, 12>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
TEST(SequenceOperators, DivisionWithNumber)
{
    // Sequence / Number is expected to divide every element by the number.
    using Input = Sequence<10, 20, 30>;
    using Got   = decltype(Input{} / Number<5>{});

    constexpr bool matches = is_same<Got, Sequence<2, 4, 6>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
TEST(SequenceOperators, NumberAddition)
{
    // Number + Sequence (number on the left) is expected to add the number to every element.
    using Input = Sequence<1, 2, 3>;
    using Got   = decltype(Number<10>{} + Input{});

    constexpr bool matches = is_same<Got, Sequence<11, 12, 13>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
TEST(SequenceOperators, NumberMultiplication)
{
    // Number * Sequence (number on the left) is expected to scale every element.
    using Input = Sequence<2, 3, 4>;
    using Got   = decltype(Number<3>{} * Input{});

    constexpr bool matches = is_same<Got, Sequence<6, 9, 12>>::value;
    EXPECT_TRUE(matches);
}
|
||||
|
||||
// Test helper functions
|
||||
TEST(SequenceHelpers, MergeSequences)
{
    // merge_sequences on three values is expected to concatenate them in argument order.
    using First  = Sequence<1, 2>;
    using Second = Sequence<3, 4>;
    using Third  = Sequence<5, 6>;

    using Want = Sequence<1, 2, 3, 4, 5, 6>;
    using Got  = decltype(merge_sequences(First{}, Second{}, Third{}));

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceHelpers, TransformSequencesSingle)
{
    // Unary functor applied to each element: every value is doubled.
    auto twice = [](auto value) { return 2 * value; };

    using Input = Sequence<1, 2, 3>;
    using Want  = Sequence<2, 4, 6>;
    using Got   = decltype(transform_sequences(twice, Input{}));

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceHelpers, TransformSequencesTwo)
{
    // Binary functor applied to matching elements of two sequences.
    auto sum_of_two = [](auto lhs, auto rhs) { return lhs + rhs; };

    using Lhs  = Sequence<1, 2, 3>;
    using Rhs  = Sequence<4, 5, 6>;
    using Want = Sequence<5, 7, 9>;
    using Got  = decltype(transform_sequences(sum_of_two, Lhs{}, Rhs{}));

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceHelpers, TransformSequencesThree)
{
    // Ternary functor applied to matching elements of three sequences.
    auto sum_of_three = [](auto a, auto b, auto c) { return a + b + c; };

    using First  = Sequence<1, 2, 3>;
    using Second = Sequence<4, 5, 6>;
    using Third  = Sequence<7, 8, 9>;

    using Want = Sequence<12, 15, 18>;
    using Got  = decltype(transform_sequences(sum_of_three, First{}, Second{}, Third{}));

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceHelpers, ReduceOnSequence)
{
    // Folding 1..5 with addition, starting from 0, is expected to give 15.
    auto accumulate = [](auto lhs, auto rhs) { return lhs + rhs; };

    constexpr auto values = Sequence<1, 2, 3, 4, 5>{};
    constexpr auto total  = reduce_on_sequence(values, accumulate, Number<0>{});

    EXPECT_EQ(total, 15);
}
|
||||
|
||||
TEST(SequenceHelpers, SequenceAnyOf)
{
    // Predicate: true for even values.
    auto even = [](auto value) { return value % 2 == 0; };

    constexpr auto all_odd  = Sequence<1, 3, 5, 7>{};
    constexpr auto one_even = Sequence<1, 3, 4, 7>{};

    // No element matches in the first sequence; the 4 matches in the second.
    EXPECT_FALSE(sequence_any_of(all_odd, even));
    EXPECT_TRUE(sequence_any_of(one_even, even));
}
|
||||
|
||||
TEST(SequenceHelpers, SequenceAllOf)
{
    // Predicate: true for strictly positive values.
    auto positive = [](auto value) { return value > 0; };

    constexpr auto all_positive = Sequence<1, 2, 3, 4>{};
    constexpr auto one_negative = Sequence<1, -2, 3, 4>{};

    // Every element matches in the first sequence; the -2 fails in the second.
    EXPECT_TRUE(sequence_all_of(all_positive, positive));
    EXPECT_FALSE(sequence_all_of(one_negative, positive));
}
|
||||
|
||||
// Test scan operations
|
||||
TEST(SequenceScan, ReverseInclusiveScan)
{
    // Expected running sums taken from the right, each including its own element:
    // 4, 4+3, 4+3+2, 4+3+2+1 -> stored as <10, 9, 7, 4>.
    using Input = Sequence<1, 2, 3, 4>;
    using Want  = Sequence<10, 9, 7, 4>;
    using Got =
        decltype(reverse_inclusive_scan_sequence(Input{}, math::plus<index_t>{}, Number<0>{}));

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceScan, ReverseExclusiveScan)
{
    // Expected running sums taken from the right, each excluding its own element,
    // so the last position holds the initial value 0.
    using Input = Sequence<1, 2, 3, 4>;
    using Want  = Sequence<9, 7, 4, 0>;
    using Got =
        decltype(reverse_exclusive_scan_sequence(Input{}, math::plus<index_t>{}, Number<0>{}));

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceScan, InclusiveScan)
{
    // Expected running sums taken from the left: 1, 1+2, 1+2+3, 1+2+3+4.
    using Input = Sequence<1, 2, 3, 4>;
    using Want  = Sequence<1, 3, 6, 10>;
    using Got   = decltype(inclusive_scan_sequence(Input{}, math::plus<index_t>{}, Number<0>{}));

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
// Test pick and modify operations
|
||||
TEST(SequencePickModify, PickElementsByIds)
{
    // Picking positions 0, 2 and 4 is expected to return the 1st, 3rd and 5th values.
    using Values    = Sequence<10, 20, 30, 40, 50>;
    using Positions = Sequence<0, 2, 4>;

    using Want = Sequence<10, 30, 50>;
    using Got  = decltype(pick_sequence_elements_by_ids(Values{}, Positions{}));

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequencePickModify, PickElementsByMask)
{
    // Elements whose mask entry is 1 are expected to be kept; entries of 0 are dropped.
    using Values = Sequence<10, 20, 30, 40, 50>;
    using Keep   = Sequence<1, 0, 1, 0, 1>;

    using Want = Sequence<10, 30, 50>;
    using Got  = decltype(pick_sequence_elements_by_mask(Values{}, Keep{}));

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequencePickModify, ModifyElementsByIds)
{
    // Positions 1 and 3 are expected to be overwritten with 99 and 88; the rest stay as-is.
    using Original     = Sequence<10, 20, 30, 40, 50>;
    using Replacements = Sequence<99, 88>;
    using Positions    = Sequence<1, 3>;

    using Want = Sequence<10, 99, 30, 88, 50>;
    using Got  = decltype(modify_sequence_elements_by_ids(Original{}, Replacements{}, Positions{}));

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
// Test sequence_reduce
|
||||
TEST(SequenceReduce, ReduceTwoSequences)
{
    // Reducing two sequences with math::plus is expected to give their element-wise sum.
    using Lhs  = Sequence<1, 2, 3>;
    using Rhs  = Sequence<4, 5, 6>;
    using Want = Sequence<5, 7, 9>;
    using Got  = typename sequence_reduce<math::plus<index_t>, Lhs, Rhs>::type;

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
|
||||
TEST(SequenceReduce, ReduceMultipleSequences)
{
    // Reducing three sequences with math::plus is expected to sum matching elements:
    // 1+3+5 and 2+4+6.
    using First  = Sequence<1, 2>;
    using Second = Sequence<3, 4>;
    using Third  = Sequence<5, 6>;

    using Want = Sequence<9, 12>;
    using Got  = typename sequence_reduce<math::plus<index_t>, First, Second, Third>::type;

    EXPECT_TRUE((is_same<Got, Want>::value));
}
|
||||
Reference in New Issue
Block a user