From 2b1c154ad41914438b332c3147ed94929ecc944a Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Mon, 2 Mar 2026 00:54:14 -0800 Subject: [PATCH] [CK_TILE] Fix CShuffleEpilogue test to use correct GEMM accumulator distribution (#4518) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary The test was using LDS distribution to create the accumulator tile, but CShuffleEpilogue expects the GEMM accumulator distribution that BlockGemm produces. This mismatch caused incorrect data permutation. ## Changes - Use WarpGemmDispatcher to get correct accumulator distribution encoding - Load test input from host-initialized global memory for deterministic verification - Shard tests by data type (FP16, FP8) with gfx950-specific FP8 tests - Extract scale tests into separate target for better organization - Implement exact permutation verification (all unique values appear once) - Reduce tile size from 256x256 to 128x128 to fit in unique fp16 range - Add parameterized test configurations for various warp layouts and MFMA types ## Test plan - [x] Run new cshuffle epilogue tests 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --------- Co-authored-by: Claude Co-authored-by: systems-assistant[bot] --- test/ck_tile/epilogue/CMakeLists.txt | 15 +- .../epilogue/test_cshuffle_epilogue.cpp | 121 ------- .../test_cshuffle_epilogue_common.hpp | 142 ++++++++ .../epilogue/test_cshuffle_epilogue_fp16.cpp | 43 +++ .../epilogue/test_cshuffle_epilogue_fp8.cpp | 48 +++ .../test_cshuffle_epilogue_fp8_gfx950.cpp | 49 +++ .../epilogue/test_cshuffle_epilogue_scale.cpp | 118 +++++++ .../epilogue/test_cshuffle_epilogue_util.hpp | 314 ++++++++++++++---- 8 files changed, 660 insertions(+), 190 deletions(-) delete mode 100644 test/ck_tile/epilogue/test_cshuffle_epilogue.cpp create mode 100644 test/ck_tile/epilogue/test_cshuffle_epilogue_common.hpp create mode 100644 test/ck_tile/epilogue/test_cshuffle_epilogue_fp16.cpp create mode 100644 test/ck_tile/epilogue/test_cshuffle_epilogue_fp8.cpp create mode 100644 test/ck_tile/epilogue/test_cshuffle_epilogue_fp8_gfx950.cpp create mode 100644 test/ck_tile/epilogue/test_cshuffle_epilogue_scale.cpp diff --git a/test/ck_tile/epilogue/CMakeLists.txt b/test/ck_tile/epilogue/CMakeLists.txt index 2b3ffe33cc..b408d79509 100644 --- a/test/ck_tile/epilogue/CMakeLists.txt +++ b/test/ck_tile/epilogue/CMakeLists.txt @@ -1,4 +1,17 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -add_gtest_executable(test_ck_tile_cshuffle_epilogue test_cshuffle_epilogue.cpp) +add_gtest_executable(test_ck_tile_cshuffle_epilogue_fp16 test_cshuffle_epilogue_fp16.cpp) +add_gtest_executable(test_ck_tile_cshuffle_epilogue_fp8 test_cshuffle_epilogue_fp8.cpp) +add_gtest_executable(test_ck_tile_cshuffle_epilogue_scale test_cshuffle_epilogue_scale.cpp) + +if(CK_USE_OCP_FP8) + target_compile_options(test_ck_tile_cshuffle_epilogue_fp8 PRIVATE -DCK_TILE_USE_OCP_FP8) +endif() + +if(GPU_TARGETS MATCHES "gfx950") + add_gtest_executable(test_ck_tile_cshuffle_epilogue_fp8_gfx950 test_cshuffle_epilogue_fp8_gfx950.cpp) + if(CK_USE_OCP_FP8) + target_compile_options(test_ck_tile_cshuffle_epilogue_fp8_gfx950 PRIVATE -DCK_TILE_USE_OCP_FP8) + endif() +endif() diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue.cpp b/test/ck_tile/epilogue/test_cshuffle_epilogue.cpp deleted file mode 100644 index 9fbe883e32..0000000000 --- a/test/ck_tile/epilogue/test_cshuffle_epilogue.cpp +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_cshuffle_epilogue_util.hpp" -#include -#include - -using namespace ck_tile; - -class CShuffleEpilogueTest : public ::testing::Test -{ - protected: - void SetUp() override {} -}; - -TEST_F(CShuffleEpilogueTest, BasicHalfTest) -{ - // Basic test configuration with half_t data types - using ADataType = ck_tile::half_t; - using BDataType = ck_tile::half_t; - using AccDataType = float; - using ODataType = ck_tile::half_t; - - constexpr index_t kMPerBlock = 256; - constexpr index_t kNPerBlock = 256; - constexpr index_t MWave = 2; - constexpr index_t NWave = 2; - constexpr index_t MPerXdl = 32; - constexpr index_t NPerXdl = 32; - constexpr index_t KPerXdl = 8; - - using TestProblem = SimpleCShuffleEpilogueProblem; - - auto result = run_cshuffle_epilogue_test(ScaleType::None); - EXPECT_FLOAT_EQ(result[0], 2.0F) << "Basic CShuffleEpilogue test failed"; -} - -TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale) -{ - // Basic test configuration with half_t data types - using ADataType = ck_tile::half_t; - using BDataType = ck_tile::half_t; - using AccDataType = float; - using ODataType = ck_tile::half_t; - - constexpr index_t kMPerBlock = 256; - constexpr index_t kNPerBlock = 256; - constexpr index_t MWave = 2; - constexpr index_t NWave = 2; - constexpr index_t MPerXdl = 32; - constexpr index_t NPerXdl = 32; - constexpr index_t KPerXdl = 8; - - using TestProblem = SimpleCShuffleEpilogueProblem; - - auto result = - run_cshuffle_epilogue_test(ScaleType::RowCol); - EXPECT_FLOAT_EQ(result[0], 2.0F) << "RowCol CShuffleEpilogue test failed: first element not 2"; - EXPECT_FLOAT_EQ(result[1], 4.0F) - << "RowCol CShuffleEpilogue test failed: second element not 2*2"; -} - -TEST_F(CShuffleEpilogueTest, BasicHalfTestWithTensorScale) -{ - // Basic test configuration with half_t data types - using ADataType = ck_tile::half_t; - using BDataType = ck_tile::half_t; - using AccDataType = float; - using ODataType = ck_tile::half_t; - - constexpr index_t kMPerBlock = 256; - constexpr index_t kNPerBlock = 256; - constexpr index_t MWave = 2; - constexpr index_t NWave = 2; - constexpr index_t MPerXdl = 32; - constexpr index_t NPerXdl = 32; - constexpr index_t KPerXdl = 8; - - using TestProblem = SimpleCShuffleEpilogueProblem; - - auto result = - run_cshuffle_epilogue_test(ScaleType::Tensor); - EXPECT_FLOAT_EQ(result[0], 4.0F) - << "TensorScale CShuffleEpilogue test failed: first element not 2*2=4"; -} - -int main(int argc, char** argv) -{ - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_common.hpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_common.hpp new file mode 100644 index 0000000000..d791df263b --- /dev/null +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_common.hpp @@ -0,0 +1,142 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * CShuffleEpilogue Test Infrastructure + * + * File organization: + * - test_cshuffle_epilogue_common.hpp: TileConfig template, verification helpers, + * typed test suite definition (this file) + * - test_cshuffle_epilogue_util.hpp: Kernel templates, launch helpers, test runners + * - test_cshuffle_epilogue_fp16.cpp: FP16 tile configurations + * - test_cshuffle_epilogue_fp8.cpp: FP8 tile configurations (standard) + * - test_cshuffle_epilogue_fp8_gfx950.cpp: FP8 configurations for gfx950 + * - test_cshuffle_epilogue_scale.cpp: RowCol and Tensor scaling tests + */ + +#pragma once + +#include "test_cshuffle_epilogue_util.hpp" +#include +#include +#include +#include + +// TileConfig defines a test configuration for CShuffleEpilogue. +// - ODataType_: The output data type written to global memory +// - MfmaDataType_: The data type used for MFMA instruction selection (determines valid KPerXdl) +// Defaults to ODataType_ but can differ (e.g., FP8 MFMA tiles with FP16 output +// to avoid FP8 range limitations in test verification) +template +struct TileConfig +{ + using DataType = ODataType_; + using MfmaDataType = MfmaDataType_; + static constexpr ck_tile::index_t kMPerBlock = MPerBlock_; + static constexpr ck_tile::index_t kNPerBlock = NPerBlock_; + static constexpr ck_tile::index_t MWave = MWave_; + static constexpr ck_tile::index_t NWave = NWave_; + static constexpr ck_tile::index_t MPerXdl = MPerXdl_; + static constexpr ck_tile::index_t NPerXdl = NPerXdl_; + static constexpr ck_tile::index_t KPerXdl = KPerXdl_; +}; + +// Helper to construct SimpleCShuffleEpilogueProblem from TileConfig +// Uses MfmaDataType for MFMA input types (A/B) and DataType for output +template +using MakeProblem = ck_tile::SimpleCShuffleEpilogueProblem; + +// Verification helper: check that output contains valid data from the epilogue shuffle. +// The C-shuffle epilogue loads thread-local values and writes them to output through LDS. +// We verify: correct output size, no NaN values, no unwritten zeros, and at least +// kBlockSize unique values (one per thread). +template +void verify_permutation_output(const std::vector& sorted_vals) +{ + constexpr size_t expected_size = static_cast(kMPerBlock * kNPerBlock); + + ASSERT_EQ(sorted_vals.size(), expected_size) << "Output size mismatch"; + + // Verify no NaN values + for(size_t i = 0; i < sorted_vals.size(); ++i) + { + ASSERT_FALSE(std::isnan(sorted_vals[i])) << "NaN at index " << i; + } + + // Count unique values using bit-exact comparison (sorted fp32 values from fp16 should be + // distinct) + size_t num_unique = 1; + for(size_t i = 1; i < sorted_vals.size(); ++i) + { + if(ck_tile::bit_cast(sorted_vals[i]) != + ck_tile::bit_cast(sorted_vals[i - 1])) + { + ++num_unique; + } + } + + // Verify exact permutation: all input values should appear exactly once in output + EXPECT_EQ(num_unique, expected_size) << "Expected exact permutation with " << expected_size + << " unique values, got " << num_unique; +} + +// Type-parameterized test fixture +template +class CShuffleEpilogueTypedTest : public ::testing::Test +{ +}; + +TYPED_TEST_SUITE_P(CShuffleEpilogueTypedTest); + +TYPED_TEST_P(CShuffleEpilogueTypedTest, BasicTest) +{ + using Config = TypeParam; + using DataType = typename Config::DataType; + + constexpr ck_tile::index_t kMPerBlock = Config::kMPerBlock; + constexpr ck_tile::index_t kNPerBlock = Config::kNPerBlock; + + using TestProblem = MakeProblem; + constexpr ck_tile::index_t kBlockSize = TestProblem::kBlockSize; + + auto host_output = ck_tile::run_cshuffle_epilogue_test( + ck_tile::ScaleType::None); + + // Convert output to sorted vector and verify using existing helper + auto output_vals = ck_tile::convert_and_sort_output(host_output); + verify_permutation_output(output_vals); +} + +REGISTER_TYPED_TEST_SUITE_P(CShuffleEpilogueTypedTest, BasicTest); + +// Allow this test suite to be included without instantiation (e.g., in scale tests) +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CShuffleEpilogueTypedTest); + +// Macro to instantiate typed test suites with suppressed clang warnings +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define CK_INSTANTIATE_TYPED_TEST_SUITE(Prefix, Suite, Types) \ + _Pragma("clang diagnostic push") \ + _Pragma("clang diagnostic ignored \"-Wused-but-marked-unused\"") \ + INSTANTIATE_TYPED_TEST_SUITE_P(Prefix, Suite, Types); \ + _Pragma("clang diagnostic pop") diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_fp16.cpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_fp16.cpp new file mode 100644 index 0000000000..12bb987347 --- /dev/null +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_fp16.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_cshuffle_epilogue_common.hpp" + +using namespace ck_tile; + +// Half precision test configurations (128x128 = 16384 elements fits in unique fp16 range) +using HalfConfig_128x128_2x2x1_32x32x8 = TileConfig; +using HalfConfig_128x128_1x4x1_16x16x16 = TileConfig; +using HalfConfig_128x128_2x2x1_16x16x16 = TileConfig; +using HalfConfig_128x128_4x1x1_16x16x16 = TileConfig; +using HalfConfig_128x128_2x2x1_32x32x16 = TileConfig; + +using HalfTestTypes = ::testing::Types; + +CK_INSTANTIATE_TYPED_TEST_SUITE(FP16, CShuffleEpilogueTypedTest, HalfTestTypes) + +// Global test environment to check for wave32 devices +class Wave32CheckEnvironment : public ::testing::Environment +{ + public: + void SetUp() override + { + int warp_size = 0; + hipError_t err = hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, 0); + if(err == hipSuccess && warp_size == 32) + { + GTEST_SKIP() << "CShuffleEpilogue tests not supported on wave32 devices"; + } + } +}; + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + ::testing::AddGlobalTestEnvironment(new Wave32CheckEnvironment); + return RUN_ALL_TESTS(); +} diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_fp8.cpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_fp8.cpp new file mode 100644 index 0000000000..36f986a747 --- /dev/null +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_fp8.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_cshuffle_epilogue_common.hpp" + +using namespace ck_tile; + +// FP8 MFMA tile configurations with half_t output +// Using half_t output avoids FP8 range limitations while testing FP8-specific tile sizes +using FP8Config_128x128_2x2x1_16x16x16 = TileConfig; +using FP8Config_128x128_1x4x1_16x16x16 = TileConfig; +using FP8Config_128x128_4x1x1_16x16x16 = TileConfig; +using FP8Config_128x128_2x2x1_32x32x16 = TileConfig; +using FP8Config_128x128_2x2x1_16x16x32 = TileConfig; +using FP8Config_128x128_2x2x1_32x32x32 = TileConfig; +using FP8Config_128x128_2x2x1_16x16x64 = TileConfig; + +using FP8TestTypes = ::testing::Types; + +CK_INSTANTIATE_TYPED_TEST_SUITE(FP8, CShuffleEpilogueTypedTest, FP8TestTypes) + +// Global test environment to check for wave32 devices +class Wave32CheckEnvironment : public ::testing::Environment +{ + public: + void SetUp() override + { + int warp_size = 0; + hipError_t err = hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, 0); + if(err == hipSuccess && warp_size == 32) + { + GTEST_SKIP() << "CShuffleEpilogue tests not supported on wave32 devices"; + } + } +}; + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + ::testing::AddGlobalTestEnvironment(new Wave32CheckEnvironment); + return RUN_ALL_TESTS(); +} diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_fp8_gfx950.cpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_fp8_gfx950.cpp new file mode 100644 index 0000000000..f5c988ef54 --- /dev/null +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_fp8_gfx950.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_cshuffle_epilogue_common.hpp" + +using namespace ck_tile; + +// FP8 MFMA tile configurations for gfx950-specific tile sizes with half_t output +// Using half_t output avoids FP8 range limitations while testing FP8-specific tile sizes +// 2x2 warp layout +using FP8Config_128x128_2x2x1_16x16x128 = TileConfig; +using FP8Config_128x128_2x2x1_32x32x64 = TileConfig; +// 1x4 warp layout +using FP8Config_128x128_1x4x1_16x16x128 = TileConfig; +using FP8Config_128x128_1x4x1_32x32x64 = TileConfig; +// 4x1 warp layout +using FP8Config_128x128_4x1x1_16x16x128 = TileConfig; +using FP8Config_128x128_4x1x1_32x32x64 = TileConfig; + +using FP8Gfx950TestTypes = ::testing::Types; + +CK_INSTANTIATE_TYPED_TEST_SUITE(FP8Gfx950, CShuffleEpilogueTypedTest, FP8Gfx950TestTypes) + +// Global test environment to check for wave32 devices +class Wave32CheckEnvironment : public ::testing::Environment +{ + public: + void SetUp() override + { + int warp_size = 0; + hipError_t err = hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, 0); + if(err == hipSuccess && warp_size == 32) + { + GTEST_SKIP() << "CShuffleEpilogue tests not supported on wave32 devices"; + } + } +}; + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + ::testing::AddGlobalTestEnvironment(new Wave32CheckEnvironment); + return RUN_ALL_TESTS(); +} diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_scale.cpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_scale.cpp new file mode 100644 index 0000000000..dd57ea64fa --- /dev/null +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_scale.cpp @@ -0,0 +1,118 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_cshuffle_epilogue_common.hpp" + +using namespace ck_tile; + +namespace { +constexpr float kScaleEpsilon = 0.001F; +constexpr float kTestScaleFactor = 2.0F; +constexpr ck_tile::index_t kScaledColIndex = 1; +} // namespace + +// Half precision test configuration for scale tests (128x128 fits in unique fp16 range) +using HalfConfig = TileConfig; +using ScaleTestProblem = MakeProblem; + +class CShuffleEpilogueScaleTest : public ::testing::Test +{ +}; + +TEST_F(CShuffleEpilogueScaleTest, HalfTestWithRowColScale) +{ + // Run both unscaled and scaled tests + auto results = run_scale_comparison_test(); + + // With RowCol scaling, column kScaledColIndex is scaled by kTestScaleFactor + // while other columns are scaled by kIdentityScale. + // Verify scaling behavior for the first MPerXdl * MWave rows. + const index_t rows_to_check = + std::min(HalfConfig::kMPerBlock, HalfConfig::MPerXdl * HalfConfig::MWave); + + constexpr index_t kUnscaledCol = 0; + constexpr index_t kScaledCol = kScaledColIndex; + + size_t col0_unchanged_count = 0; + size_t col1_scaled_count = 0; + + for(index_t row = 0; row < rows_to_check; ++row) + { + const size_t col0_idx = static_cast(row * HalfConfig::kNPerBlock + kUnscaledCol); + const size_t col1_idx = static_cast(row * HalfConfig::kNPerBlock + kScaledCol); + + const auto unscaled_col0 = type_convert(results.first.mData[col0_idx]); + const auto scaled_col0 = type_convert(results.second.mData[col0_idx]); + const auto unscaled_col1 = type_convert(results.first.mData[col1_idx]); + const auto scaled_col1 = type_convert(results.second.mData[col1_idx]); + + // Count rows where column 0 is unchanged (scale = kIdentityScale) + if(std::abs(scaled_col0 - unscaled_col0) < kScaleEpsilon) + { + col0_unchanged_count++; + } + + // Count rows where column 1 is scaled by kTestScaleFactor + const float expected_scaled = unscaled_col1 * kTestScaleFactor; + if(std::abs(scaled_col1 - expected_scaled) < kScaleEpsilon) + { + col1_scaled_count++; + } + } + + // All rows must have correct scaling + EXPECT_EQ(col0_unchanged_count, static_cast(rows_to_check)) + << "RowCol: not all rows have unchanged col0"; + EXPECT_EQ(col1_scaled_count, static_cast(rows_to_check)) + << "RowCol: not all rows have scaled col1"; +} + +TEST_F(CShuffleEpilogueScaleTest, HalfTestWithTensorScale) +{ + // Run both unscaled and scaled tests + auto results = run_scale_comparison_test(); + + // Convert both to sorted vectors using helper + auto unscaled_vals = convert_and_sort_output(results.first); + auto scaled_vals = convert_and_sort_output(results.second); + + // With Tensor scaling (m_scale=kTestScaleFactor, n_scale=kIdentityScale), + // all values should be scaled by kTestScaleFactor + EXPECT_EQ(unscaled_vals.size(), scaled_vals.size()) << "Tensor scale: output sizes differ"; + + for(size_t i = 0; i < unscaled_vals.size(); ++i) + { + const float expected = unscaled_vals[i] * kTestScaleFactor; + EXPECT_NEAR(scaled_vals[i], expected, kScaleEpsilon) + << "Tensor scale: sorted scaled[" << i << "]=" << scaled_vals[i] << " should be " + << kTestScaleFactor << "x " << unscaled_vals[i]; + } +} + +// Global test environment to check for wave32 devices +class Wave32CheckEnvironment : public ::testing::Environment +{ + public: + void SetUp() override + { + int warp_size = 0; + hipError_t err = hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, 0); + if(err == hipSuccess && warp_size == 32) + { + GTEST_SKIP() << "CShuffleEpilogue tests not supported on wave32 devices"; + } + } +}; + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + ::testing::AddGlobalTestEnvironment(new Wave32CheckEnvironment); + return RUN_ALL_TESTS(); +} diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp index 0572115201..2c258b5bb9 100644 --- a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp @@ -4,16 +4,18 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/host/device_memory.hpp" #include "ck_tile/host/hip_check_error.hpp" +#include "ck_tile/host/host_tensor.hpp" #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" #include "ck_tile/ops/elementwise.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include +#include #include -#include -#include -#include +#include #include #include @@ -28,11 +30,14 @@ enum class ScaleType // Simple test kernel to invoke the CShuffleEpilogue template -__global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __restrict__ output_data, - float* m_scale, - float* n_scale) +__global__ void +test_cshuffle_epilogue_kernel(const typename Problem::AccDataType* __restrict__ input_data, + typename Problem::ODataType* __restrict__ output_data, + float* m_scale, + float* n_scale) { - using Epilogue = CShuffleEpilogue; + using Epilogue = CShuffleEpilogue; + using AccDataType = typename Epilogue::AccDataType; static_assert(Problem::kMPerBlock <= M && Problem::kNPerBlock <= N, "Block size must fit in tensor dimensions"); @@ -40,15 +45,55 @@ __global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __res // Allocate shared memory for epilogue __shared__ char smem[Epilogue::GetSmemSize()]; - // Create accumulator tile - constexpr auto lds_distribution_encode = - make_static_tile_distribution(Epilogue::MakeLdsDistributionEncode()); - auto acc_tile = - make_static_distributed_tensor(lds_distribution_encode); + // Create accumulator tile with GEMM accumulator distribution (matches BlockGemm) + using WG = ck_tile::WarpGemmDispatcher; - // Fill acc_tile with a simple pattern - auto& acc_buffer = acc_tile.get_thread_buffer(); - acc_buffer[0] = 2.0F; + constexpr index_t MIterPerWarp = Problem::kMPerBlock / (Problem::MWave * Problem::MPerXdl); + constexpr index_t NIterPerWarp = Problem::kNPerBlock / (Problem::NWave * Problem::NPerXdl); + + constexpr auto c_block_outer_dstr_encoding = ck_tile::tile_distribution_encoding< + ck_tile::sequence<>, + ck_tile::tuple, + ck_tile::sequence>, + ck_tile::tuple>, + ck_tile::tuple>, + ck_tile::sequence<1, 2>, + ck_tile::sequence<0, 0>>{}; + + constexpr auto acc_distribution_encode = ck_tile::detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto acc_distribution = make_static_tile_distribution(acc_distribution_encode); + auto acc_tile = make_static_distributed_tensor(acc_distribution); + + // Create input tensor view for loading from global memory + // Note: cast away const since buffer_view infrastructure doesn't support const pointers, + // but the input_data is only read, never written + // Use runtime values for dimensions to avoid issues with constant buffer size types + constexpr index_t kMPerBlock = Problem::kMPerBlock; + constexpr index_t kNPerBlock = Problem::kNPerBlock; + auto input_tensor_view = make_naive_tensor_view( + const_cast(input_data), + make_tuple(kMPerBlock, kNPerBlock), + make_tuple(kNPerBlock, 1), // row-major strides + number<1>{}, + number<1>{}); + + // Create tile window using the correct accumulator distribution + auto input_tile_window = + make_tile_window(input_tensor_view, + make_tuple(number{}, number{}), + {0, 0}, + acc_distribution); // Use GEMM acc distribution, not LDS distribution + + // Load input data from global memory into acc_tile + load_tile(acc_tile, input_tile_window); // Create output tensor view auto output_tensor_view = @@ -123,83 +168,216 @@ using SimpleCShuffleEpilogueProblem = false // isCTransposed >; +// Launch kernel with RowCol scaling +template +void launch_kernel_with_rowcol_scale(const typename Problem::AccDataType* device_input, + typename Problem::ODataType* device_output, + dim3 gridSize, + dim3 blockSize) +{ + HostTensor h_m_scale({M}); + HostTensor h_n_scale({N}); + for(index_t i = 0; i < M; ++i) + { + h_m_scale.mData[i] = 1.0F; + } + for(index_t i = 0; i < N; ++i) + { + h_n_scale.mData[i] = 1.0F; + } + h_n_scale.mData[1] = 2.0F; + + DeviceMem m_scale_buf(h_m_scale.get_element_space_size_in_bytes()); + DeviceMem n_scale_buf(h_n_scale.get_element_space_size_in_bytes()); + m_scale_buf.ToDevice(h_m_scale.data()); + n_scale_buf.ToDevice(h_n_scale.data()); + + test_cshuffle_epilogue_kernel + <<>>(device_input, + device_output, + static_cast(m_scale_buf.GetDeviceBuffer()), + static_cast(n_scale_buf.GetDeviceBuffer())); + HIP_CHECK_ERROR(hipGetLastError()); + HIP_CHECK_ERROR(hipDeviceSynchronize()); +} + +// Launch kernel with Tensor scaling +template +void launch_kernel_with_tensor_scale(const typename Problem::AccDataType* device_input, + typename Problem::ODataType* device_output, + dim3 gridSize, + dim3 blockSize) +{ + HostTensor h_m_scale({1}); + HostTensor h_n_scale({1}); + h_m_scale.mData[0] = 2.0F; + h_n_scale.mData[0] = 1.0F; + + DeviceMem m_scale_buf(h_m_scale.get_element_space_size_in_bytes()); + DeviceMem n_scale_buf(h_n_scale.get_element_space_size_in_bytes()); + m_scale_buf.ToDevice(h_m_scale.data()); + n_scale_buf.ToDevice(h_n_scale.data()); + + test_cshuffle_epilogue_kernel + <<>>(device_input, + device_output, + static_cast(m_scale_buf.GetDeviceBuffer()), + static_cast(n_scale_buf.GetDeviceBuffer())); + HIP_CHECK_ERROR(hipGetLastError()); + HIP_CHECK_ERROR(hipDeviceSynchronize()); +} + +// Launch kernel without scaling +template +void launch_kernel_without_scale(const typename Problem::AccDataType* device_input, + typename Problem::ODataType* device_output, + dim3 gridSize, + dim3 blockSize) +{ + test_cshuffle_epilogue_kernel + <<>>(device_input, device_output, nullptr, nullptr); + HIP_CHECK_ERROR(hipGetLastError()); + HIP_CHECK_ERROR(hipDeviceSynchronize()); +} + +/// Generate N unique fp16 bit patterns from the normal range. +/// Uses positive normals (0x0400-0x7BFF) first, then negative normals (0x8400-0xFBFF). +/// Static asserts if N > 61440 (max unique normal fp16 values). +template +constexpr std::array generate_fp16_bit_patterns() +{ + static_assert(N <= 61440, "N exceeds available unique normal fp16 values"); + + std::array result{}; + constexpr uint16_t kPosStart = 0x0400; + constexpr uint16_t kNegStart = 0x8400; + constexpr size_t kMaxPositiveNormals = 30720; + + for(size_t i = 0; i < N; ++i) + { + result[i] = (i < kMaxPositiveNormals) + ? static_cast(kPosStart + i) + : static_cast(kNegStart + (i - kMaxPositiveNormals)); + } + return result; +} + +/// Convert fp16 bit patterns to float values. +/// Performs: uint16_t -> half_t (bit_cast) -> float +template +std::array convert_fp16_bits(const std::array& bits) +{ + std::array result; + for(size_t i = 0; i < N; ++i) + { + half_t h = bit_cast(bits[i]); + result[i] = type_convert(h); + } + return result; +} + +/// Generate unique fp16 values as a HostTensor for permutation testing. +/// Uses layered architecture: bit patterns -> type conversion -> HostTensor. +template +HostTensor generate_unique_fp16_input() +{ + constexpr size_t N = static_cast(Rows * Cols); + + constexpr auto bits = generate_fp16_bit_patterns(); + auto values = convert_fp16_bits(bits); + + HostTensor host_input({Rows, Cols}); + for(index_t m = 0; m < Rows; ++m) + { + for(index_t n = 0; n < Cols; ++n) + { + host_input(m, n) = static_cast(values[static_cast(m * Cols + n)]); + } + } + return host_input; +} + template auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None) { - using ODataType = typename Problem::ODataType; + using AccDataType = typename Problem::AccDataType; + using ODataType = typename Problem::ODataType; constexpr index_t kMPerBlock = Problem::kMPerBlock; constexpr index_t kNPerBlock = Problem::kNPerBlock; - index_t kBlockSize = ck_tile::is_wave32() ? Problem::kBlockSize / 2 : Problem::kBlockSize; + const index_t kBlockSize = ck_tile::is_wave32() ? Problem::kBlockSize / 2 : Problem::kBlockSize; std::cout << "Running CShuffleEpilogue test with M=" << M << ", N=" << N << ", MPerBlock=" << kMPerBlock << ", NPerBlock=" << kNPerBlock << ", BlockSize=" << kBlockSize << std::endl; - // Allocate host memory - const size_t output_size = M * N; + HostTensor host_input = + generate_unique_fp16_input(); - std::vector host_output(output_size, static_cast(0)); + // Allocate device input and copy from host + DeviceMem device_input_buf(host_input.get_element_space_size_in_bytes()); + device_input_buf.ToDevice(host_input.data()); + auto* device_input = static_cast(device_input_buf.GetDeviceBuffer()); - // Allocate device memory - ODataType* device_output; + // Allocate host output memory + HostTensor host_output({M, N}); + host_output.SetZero(); - HIP_CHECK_ERROR(hipMalloc(&device_output, output_size * sizeof(ODataType))); + // Allocate device output memory + DeviceMem device_output_buf(host_output.get_element_space_size_in_bytes()); + device_output_buf.ToDevice(host_output.data()); + ODataType* device_output = static_cast(device_output_buf.GetDeviceBuffer()); - HIP_CHECK_ERROR(hipMemcpy( - device_output, host_output.data(), output_size * sizeof(ODataType), hipMemcpyHostToDevice)); - - // Launch kernel + // Launch kernel with appropriate scale configuration dim3 gridSize(1, 1, 1); dim3 blockSize(kBlockSize, 1, 1); - if(scale == ScaleType::RowCol) + switch(scale) { - float* m_scale; - float* n_scale; - std::vector h_m_scale(M, 1.0F); - std::vector h_n_scale(N, 1.0F); - h_n_scale[1] = 2.0F; // multiply one col only with 2 - HIP_CHECK_ERROR(hipMalloc(&m_scale, M * sizeof(float))); - HIP_CHECK_ERROR(hipMalloc(&n_scale, N * sizeof(float))); - HIP_CHECK_ERROR( - hipMemcpy(m_scale, h_m_scale.data(), M * sizeof(float), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR( - hipMemcpy(n_scale, h_n_scale.data(), N * sizeof(float), hipMemcpyHostToDevice)); - test_cshuffle_epilogue_kernel - <<>>(device_output, m_scale, n_scale); + case ScaleType::RowCol: + launch_kernel_with_rowcol_scale( + device_input, device_output, gridSize, blockSize); + break; + case ScaleType::Tensor: + launch_kernel_with_tensor_scale( + device_input, device_output, gridSize, blockSize); + break; + case ScaleType::None: + launch_kernel_without_scale( + device_input, device_output, gridSize, blockSize); + break; } - else if(scale == ScaleType::Tensor) - { - float* m_scale; - float* n_scale; - std::vector h_m_scale(1, 2.0F); - std::vector h_n_scale(1, 1.0F); - HIP_CHECK_ERROR(hipMalloc(&m_scale, sizeof(float))); - HIP_CHECK_ERROR(hipMalloc(&n_scale, sizeof(float))); - HIP_CHECK_ERROR(hipMemcpy(m_scale, h_m_scale.data(), sizeof(float), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(n_scale, h_n_scale.data(), sizeof(float), hipMemcpyHostToDevice)); - test_cshuffle_epilogue_kernel - <<>>(device_output, m_scale, n_scale); - } - else - { - test_cshuffle_epilogue_kernel - <<>>(device_output, nullptr, nullptr); - } - - // Check for kernel launch errors - HIP_CHECK_ERROR(hipGetLastError()); - HIP_CHECK_ERROR(hipDeviceSynchronize()); // Copy results back - HIP_CHECK_ERROR(hipMemcpy( - host_output.data(), device_output, output_size * sizeof(ODataType), hipMemcpyDeviceToHost)); - - // Cleanup - HIP_CHECK_ERROR(hipFree(device_output)); + device_output_buf.FromDevice(host_output.data()); return host_output; } +// Convert output values to sorted float vector for verification +// Uses float as intermediate to preserve precision for floating-point comparison +template +std::vector convert_and_sort_output(const HostTensor& output) +{ + std::vector result; + result.reserve(output.get_element_size()); + for(size_t i = 0; i < output.get_element_size(); ++i) + { + result.push_back(type_convert(output.mData[i])); + } + std::sort(result.begin(), result.end()); + return result; +} + +// Run both unscaled and scaled tests for comparison +// Returns pair of (unscaled_output, scaled_output) host tensors +template +auto run_scale_comparison_test() +{ + auto unscaled_output = run_cshuffle_epilogue_test(ScaleType::None); + auto scaled_output = run_cshuffle_epilogue_test(ScaleMode); + + return std::make_pair(std::move(unscaled_output), std::move(scaled_output)); +} + } // namespace ck_tile