[rocm-libraries] ROCm/rocm-libraries#4518 (commit dd161dc)

[CK_TILE] Fix CShuffleEpilogue test to use correct GEMM accumulator distribution (#4518)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

The test was using LDS distribution to create the accumulator tile, but
CShuffleEpilogue expects the GEMM accumulator distribution that
BlockGemm produces. This mismatch caused incorrect data permutation.

## Changes

- Use WarpGemmDispatcher to get correct accumulator distribution
encoding
- Load test input from host-initialized global memory for deterministic
verification
- Shard tests by data type (FP16, FP8) with gfx950-specific FP8 tests
- Extract scale tests into separate target for better organization
- Implement exact permutation verification (all unique values appear
once)
- Reduce tile size from 256x256 to 128x128 to fit in unique fp16 range
- Add parameterized test configurations for various warp layouts and
MFMA types

## Test plan

- [x] Run new cshuffle epilogue tests

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Authored by Max Podkorytov on 2026-03-02 08:55:05 +00:00; committed by assistant-librarian[bot].
parent 78ae3835a6
commit 0438ab1b79
8 changed files with 660 additions and 190 deletions

View File

@@ -1,4 +1,17 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# NOTE(review): test_cshuffle_epilogue.cpp is deleted by this change; the next line
# looks like the pre-change target carried over from the diff view — confirm it is gone.
add_gtest_executable(test_ck_tile_cshuffle_epilogue test_cshuffle_epilogue.cpp)
# One executable per shard (fp16 / fp8 / scale) keeps build and run times bounded.
add_gtest_executable(test_ck_tile_cshuffle_epilogue_fp16 test_cshuffle_epilogue_fp16.cpp)
add_gtest_executable(test_ck_tile_cshuffle_epilogue_fp8 test_cshuffle_epilogue_fp8.cpp)
add_gtest_executable(test_ck_tile_cshuffle_epilogue_scale test_cshuffle_epilogue_scale.cpp)
if(CK_USE_OCP_FP8)
    # Propagate the fp8 format selection to the fp8 shard.
    target_compile_options(test_ck_tile_cshuffle_epilogue_fp8 PRIVATE -DCK_TILE_USE_OCP_FP8)
endif()
# gfx950-only fp8 tile configurations get their own target, built only when targeted.
if(GPU_TARGETS MATCHES "gfx950")
    add_gtest_executable(test_ck_tile_cshuffle_epilogue_fp8_gfx950 test_cshuffle_epilogue_fp8_gfx950.cpp)
    if(CK_USE_OCP_FP8)
        target_compile_options(test_ck_tile_cshuffle_epilogue_fp8_gfx950 PRIVATE -DCK_TILE_USE_OCP_FP8)
    endif()
endif()

View File

@@ -1,121 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_cshuffle_epilogue_util.hpp"
#include <gtest/gtest.h>
#include <hip/hip_runtime.h>
using namespace ck_tile;
class CShuffleEpilogueTest : public ::testing::Test
{
protected:
void SetUp() override {}
};
TEST_F(CShuffleEpilogueTest, BasicHalfTest)
{
// Basic test configuration with half_t data types
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using ODataType = ck_tile::half_t;
constexpr index_t kMPerBlock = 256;
constexpr index_t kNPerBlock = 256;
constexpr index_t MWave = 2;
constexpr index_t NWave = 2;
constexpr index_t MPerXdl = 32;
constexpr index_t NPerXdl = 32;
constexpr index_t KPerXdl = 8;
using TestProblem = SimpleCShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
ODataType,
kMPerBlock,
kNPerBlock,
MWave,
NWave,
MPerXdl,
NPerXdl,
KPerXdl>;
auto result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::None);
EXPECT_FLOAT_EQ(result[0], 2.0F) << "Basic CShuffleEpilogue test failed";
}
TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale)
{
// Basic test configuration with half_t data types
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using ODataType = ck_tile::half_t;
constexpr index_t kMPerBlock = 256;
constexpr index_t kNPerBlock = 256;
constexpr index_t MWave = 2;
constexpr index_t NWave = 2;
constexpr index_t MPerXdl = 32;
constexpr index_t NPerXdl = 32;
constexpr index_t KPerXdl = 8;
using TestProblem = SimpleCShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
ODataType,
kMPerBlock,
kNPerBlock,
MWave,
NWave,
MPerXdl,
NPerXdl,
KPerXdl>;
auto result =
run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::RowCol);
EXPECT_FLOAT_EQ(result[0], 2.0F) << "RowCol CShuffleEpilogue test failed: first element not 2";
EXPECT_FLOAT_EQ(result[1], 4.0F)
<< "RowCol CShuffleEpilogue test failed: second element not 2*2";
}
TEST_F(CShuffleEpilogueTest, BasicHalfTestWithTensorScale)
{
// Basic test configuration with half_t data types
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using ODataType = ck_tile::half_t;
constexpr index_t kMPerBlock = 256;
constexpr index_t kNPerBlock = 256;
constexpr index_t MWave = 2;
constexpr index_t NWave = 2;
constexpr index_t MPerXdl = 32;
constexpr index_t NPerXdl = 32;
constexpr index_t KPerXdl = 8;
using TestProblem = SimpleCShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
ODataType,
kMPerBlock,
kNPerBlock,
MWave,
NWave,
MPerXdl,
NPerXdl,
KPerXdl>;
auto result =
run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::Tensor);
EXPECT_FLOAT_EQ(result[0], 4.0F)
<< "TensorScale CShuffleEpilogue test failed: first element not 2*2=4";
}
int main(int argc, char** argv)
{
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@@ -0,0 +1,142 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* CShuffleEpilogue Test Infrastructure
*
* File organization:
* - test_cshuffle_epilogue_common.hpp: TileConfig template, verification helpers,
* typed test suite definition (this file)
* - test_cshuffle_epilogue_util.hpp: Kernel templates, launch helpers, test runners
* - test_cshuffle_epilogue_fp16.cpp: FP16 tile configurations
* - test_cshuffle_epilogue_fp8.cpp: FP8 tile configurations (standard)
* - test_cshuffle_epilogue_fp8_gfx950.cpp: FP8 configurations for gfx950
* - test_cshuffle_epilogue_scale.cpp: RowCol and Tensor scaling tests
*/
#pragma once
#include "test_cshuffle_epilogue_util.hpp"
#include <algorithm>
#include <cmath>
#include <gtest/gtest.h>
#include <vector>
// TileConfig defines a test configuration for CShuffleEpilogue.
// - ODataType_: The output data type written to global memory
// - MfmaDataType_: The data type used for MFMA instruction selection (determines valid KPerXdl)
// Defaults to ODataType_ but can differ (e.g., FP8 MFMA tiles with FP16 output
// to avoid FP8 range limitations in test verification)
template <typename ODataType_,
          ck_tile::index_t MPerBlock_,
          ck_tile::index_t NPerBlock_,
          ck_tile::index_t MWave_,
          ck_tile::index_t NWave_,
          ck_tile::index_t MPerXdl_,
          ck_tile::index_t NPerXdl_,
          ck_tile::index_t KPerXdl_,
          typename MfmaDataType_ = ODataType_>
struct TileConfig
{
    // Output element type written to global memory.
    using DataType = ODataType_;
    // Type used to select the MFMA instruction; may differ from the output type
    // (e.g. fp8 MFMA tiles verified through a half_t output).
    using MfmaDataType = MfmaDataType_;
    // Block tile extent along M / N.
    static constexpr ck_tile::index_t kMPerBlock = MPerBlock_;
    static constexpr ck_tile::index_t kNPerBlock = NPerBlock_;
    // Wave counts along M / N within one block.
    static constexpr ck_tile::index_t MWave = MWave_;
    static constexpr ck_tile::index_t NWave = NWave_;
    // Per-XDL (MFMA instruction) tile extents.
    static constexpr ck_tile::index_t MPerXdl = MPerXdl_;
    static constexpr ck_tile::index_t NPerXdl = NPerXdl_;
    static constexpr ck_tile::index_t KPerXdl = KPerXdl_;
};
// Helper to construct SimpleCShuffleEpilogueProblem from TileConfig
// Uses MfmaDataType for MFMA input types (A/B) and DataType for output
template <typename Config, typename AccDataType = float>
using MakeProblem = ck_tile::SimpleCShuffleEpilogueProblem<typename Config::MfmaDataType,
typename Config::MfmaDataType,
AccDataType,
typename Config::DataType,
Config::kMPerBlock,
Config::kNPerBlock,
Config::MWave,
Config::NWave,
Config::MPerXdl,
Config::NPerXdl,
Config::KPerXdl>;
// Verification helper: check that output is an exact permutation of the unique input.
// The C-shuffle epilogue loads thread-local values and writes them to output through LDS.
// We verify: correct output size, no NaN values, and exactly kMPerBlock * kNPerBlock
// distinct values (every unique input value appears exactly once).
template <typename DataType,
          ck_tile::index_t kMPerBlock,
          ck_tile::index_t kNPerBlock,
          ck_tile::index_t kBlockSize>
void verify_permutation_output(const std::vector<float>& sorted_vals)
{
    // DataType and kBlockSize are kept in the signature for the callers' sake;
    // the checks below only need the tile extents.
    constexpr size_t expected_size = static_cast<size_t>(kMPerBlock * kNPerBlock);
    ASSERT_EQ(sorted_vals.size(), expected_size) << "Output size mismatch";
    // Reject NaNs anywhere in the output.
    for(size_t idx = 0; idx < sorted_vals.size(); ++idx)
    {
        ASSERT_FALSE(std::isnan(sorted_vals[idx])) << "NaN at index " << idx;
    }
    // The input vector is sorted, so duplicates are adjacent; compare raw bits so
    // distinct values are counted exactly, with no float-equality tolerance.
    size_t num_unique = 1;
    for(size_t idx = 1; idx < sorted_vals.size(); ++idx)
    {
        const auto prev_bits = ck_tile::bit_cast<uint32_t>(sorted_vals[idx - 1]);
        const auto curr_bits = ck_tile::bit_cast<uint32_t>(sorted_vals[idx]);
        if(curr_bits != prev_bits)
        {
            ++num_unique;
        }
    }
    // Exact permutation: every unique input value must appear exactly once.
    EXPECT_EQ(num_unique, expected_size) << "Expected exact permutation with " << expected_size
                                         << " unique values, got " << num_unique;
}
// Type-parameterized test fixture
template <typename Config>
class CShuffleEpilogueTypedTest : public ::testing::Test
{
};
TYPED_TEST_SUITE_P(CShuffleEpilogueTypedTest);
// Runs the epilogue kernel without scaling and checks that the output is an
// exact permutation of the unique input values.
TYPED_TEST_P(CShuffleEpilogueTypedTest, BasicTest)
{
    using Config   = TypeParam;
    using DataType = typename Config::DataType;
    constexpr ck_tile::index_t kMPerBlock = Config::kMPerBlock;
    constexpr ck_tile::index_t kNPerBlock = Config::kNPerBlock;
    using TestProblem = MakeProblem<Config>;
    // Block size comes from the problem definition; forwarded to the verifier
    // as a template parameter.
    constexpr ck_tile::index_t kBlockSize = TestProblem::kBlockSize;
    // Launch the kernel; returns a HostTensor holding the device output.
    auto host_output = ck_tile::run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(
        ck_tile::ScaleType::None);
    // Convert output to a sorted float vector so uniqueness can be checked linearly.
    auto output_vals = ck_tile::convert_and_sort_output(host_output);
    verify_permutation_output<DataType, kMPerBlock, kNPerBlock, kBlockSize>(output_vals);
}
REGISTER_TYPED_TEST_SUITE_P(CShuffleEpilogueTypedTest, BasicTest);
// Allow this test suite to be included without instantiation (e.g., in scale tests)
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CShuffleEpilogueTypedTest);
// Macro to instantiate typed test suites with suppressed clang warnings
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define CK_INSTANTIATE_TYPED_TEST_SUITE(Prefix, Suite, Types) \
_Pragma("clang diagnostic push") \
_Pragma("clang diagnostic ignored \"-Wused-but-marked-unused\"") \
INSTANTIATE_TYPED_TEST_SUITE_P(Prefix, Suite, Types); \
_Pragma("clang diagnostic pop")

View File

@@ -0,0 +1,43 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_cshuffle_epilogue_common.hpp"
using namespace ck_tile;
// Half precision test configurations (128x128 = 16384 elements fits in unique fp16 range)
using HalfConfig_128x128_2x2x1_32x32x8 = TileConfig<half_t, 128, 128, 2, 2, 32, 32, 8>;
using HalfConfig_128x128_1x4x1_16x16x16 = TileConfig<half_t, 128, 128, 1, 4, 16, 16, 16>;
using HalfConfig_128x128_2x2x1_16x16x16 = TileConfig<half_t, 128, 128, 2, 2, 16, 16, 16>;
using HalfConfig_128x128_4x1x1_16x16x16 = TileConfig<half_t, 128, 128, 4, 1, 16, 16, 16>;
using HalfConfig_128x128_2x2x1_32x32x16 = TileConfig<half_t, 128, 128, 2, 2, 32, 32, 16>;
using HalfTestTypes = ::testing::Types<HalfConfig_128x128_2x2x1_32x32x8,
HalfConfig_128x128_1x4x1_16x16x16,
HalfConfig_128x128_2x2x1_16x16x16,
HalfConfig_128x128_4x1x1_16x16x16,
HalfConfig_128x128_2x2x1_32x32x16>;
CK_INSTANTIATE_TYPED_TEST_SUITE(FP16, CShuffleEpilogueTypedTest, HalfTestTypes)
// Global test environment to check for wave32 devices.
// NOTE(review): this class is duplicated verbatim in the fp8, fp8_gfx950, and
// scale test mains — consider hoisting it into the common header.
class Wave32CheckEnvironment : public ::testing::Environment
{
    public:
    void SetUp() override
    {
        // Query the warp size of device 0; if the query fails we fall through
        // and let the tests run.
        int warp_size = 0;
        hipError_t err = hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, 0);
        if(err == hipSuccess && warp_size == 32)
        {
            // GTEST_SKIP in an Environment::SetUp marks the whole suite as skipped.
            GTEST_SKIP() << "CShuffleEpilogue tests not supported on wave32 devices";
        }
    }
};
int main(int argc, char** argv)
{
::testing::InitGoogleTest(&argc, argv);
::testing::AddGlobalTestEnvironment(new Wave32CheckEnvironment);
return RUN_ALL_TESTS();
}

View File

@@ -0,0 +1,48 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_cshuffle_epilogue_common.hpp"
using namespace ck_tile;
// FP8 MFMA tile configurations with half_t output
// Using half_t output avoids FP8 range limitations while testing FP8-specific tile sizes
using FP8Config_128x128_2x2x1_16x16x16 = TileConfig<half_t, 128, 128, 2, 2, 16, 16, 16, fp8_t>;
using FP8Config_128x128_1x4x1_16x16x16 = TileConfig<half_t, 128, 128, 1, 4, 16, 16, 16, fp8_t>;
using FP8Config_128x128_4x1x1_16x16x16 = TileConfig<half_t, 128, 128, 4, 1, 16, 16, 16, fp8_t>;
using FP8Config_128x128_2x2x1_32x32x16 = TileConfig<half_t, 128, 128, 2, 2, 32, 32, 16, fp8_t>;
using FP8Config_128x128_2x2x1_16x16x32 = TileConfig<half_t, 128, 128, 2, 2, 16, 16, 32, fp8_t>;
using FP8Config_128x128_2x2x1_32x32x32 = TileConfig<half_t, 128, 128, 2, 2, 32, 32, 32, fp8_t>;
using FP8Config_128x128_2x2x1_16x16x64 = TileConfig<half_t, 128, 128, 2, 2, 16, 16, 64, fp8_t>;
using FP8TestTypes = ::testing::Types<FP8Config_128x128_2x2x1_16x16x16,
FP8Config_128x128_1x4x1_16x16x16,
FP8Config_128x128_4x1x1_16x16x16,
FP8Config_128x128_2x2x1_32x32x16,
FP8Config_128x128_2x2x1_16x16x32,
FP8Config_128x128_2x2x1_32x32x32,
FP8Config_128x128_2x2x1_16x16x64>;
CK_INSTANTIATE_TYPED_TEST_SUITE(FP8, CShuffleEpilogueTypedTest, FP8TestTypes)
// Global test environment to check for wave32 devices
class Wave32CheckEnvironment : public ::testing::Environment
{
public:
void SetUp() override
{
int warp_size = 0;
hipError_t err = hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, 0);
if(err == hipSuccess && warp_size == 32)
{
GTEST_SKIP() << "CShuffleEpilogue tests not supported on wave32 devices";
}
}
};
int main(int argc, char** argv)
{
::testing::InitGoogleTest(&argc, argv);
::testing::AddGlobalTestEnvironment(new Wave32CheckEnvironment);
return RUN_ALL_TESTS();
}

View File

@@ -0,0 +1,49 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_cshuffle_epilogue_common.hpp"
using namespace ck_tile;
// FP8 MFMA tile configurations for gfx950-specific tile sizes with half_t output
// Using half_t output avoids FP8 range limitations while testing FP8-specific tile sizes
// 2x2 warp layout
using FP8Config_128x128_2x2x1_16x16x128 = TileConfig<half_t, 128, 128, 2, 2, 16, 16, 128, fp8_t>;
using FP8Config_128x128_2x2x1_32x32x64 = TileConfig<half_t, 128, 128, 2, 2, 32, 32, 64, fp8_t>;
// 1x4 warp layout
using FP8Config_128x128_1x4x1_16x16x128 = TileConfig<half_t, 128, 128, 1, 4, 16, 16, 128, fp8_t>;
using FP8Config_128x128_1x4x1_32x32x64 = TileConfig<half_t, 128, 128, 1, 4, 32, 32, 64, fp8_t>;
// 4x1 warp layout
using FP8Config_128x128_4x1x1_16x16x128 = TileConfig<half_t, 128, 128, 4, 1, 16, 16, 128, fp8_t>;
using FP8Config_128x128_4x1x1_32x32x64 = TileConfig<half_t, 128, 128, 4, 1, 32, 32, 64, fp8_t>;
using FP8Gfx950TestTypes = ::testing::Types<FP8Config_128x128_2x2x1_16x16x128,
FP8Config_128x128_2x2x1_32x32x64,
FP8Config_128x128_1x4x1_16x16x128,
FP8Config_128x128_1x4x1_32x32x64,
FP8Config_128x128_4x1x1_16x16x128,
FP8Config_128x128_4x1x1_32x32x64>;
CK_INSTANTIATE_TYPED_TEST_SUITE(FP8Gfx950, CShuffleEpilogueTypedTest, FP8Gfx950TestTypes)
// Global test environment to check for wave32 devices
class Wave32CheckEnvironment : public ::testing::Environment
{
public:
void SetUp() override
{
int warp_size = 0;
hipError_t err = hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, 0);
if(err == hipSuccess && warp_size == 32)
{
GTEST_SKIP() << "CShuffleEpilogue tests not supported on wave32 devices";
}
}
};
int main(int argc, char** argv)
{
::testing::InitGoogleTest(&argc, argv);
::testing::AddGlobalTestEnvironment(new Wave32CheckEnvironment);
return RUN_ALL_TESTS();
}

View File

@@ -0,0 +1,118 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_cshuffle_epilogue_common.hpp"
using namespace ck_tile;
namespace {
constexpr float kScaleEpsilon = 0.001F;
constexpr float kTestScaleFactor = 2.0F;
constexpr ck_tile::index_t kScaledColIndex = 1;
} // namespace
// Half precision test configuration for scale tests (128x128 fits in unique fp16 range)
using HalfConfig = TileConfig<half_t, 128, 128, 2, 2, 32, 32, 8>;
using ScaleTestProblem = MakeProblem<HalfConfig>;
class CShuffleEpilogueScaleTest : public ::testing::Test
{
};
TEST_F(CShuffleEpilogueScaleTest, HalfTestWithRowColScale)
{
    // Run both unscaled and scaled tests over the same deterministic input.
    auto results = run_scale_comparison_test<ScaleTestProblem,
                                             HalfConfig::kMPerBlock,
                                             HalfConfig::kNPerBlock,
                                             ScaleType::RowCol>();
    // With RowCol scaling, column kScaledColIndex is scaled by kTestScaleFactor
    // while other columns are scaled by kIdentityScale.
    // Verify scaling behavior for the first MPerXdl * MWave rows.
    // NOTE(review): presumably limited to the first wave tile of rows — confirm.
    const index_t rows_to_check =
        std::min(HalfConfig::kMPerBlock, HalfConfig::MPerXdl * HalfConfig::MWave);
    constexpr index_t kUnscaledCol = 0;
    constexpr index_t kScaledCol   = kScaledColIndex;
    size_t col0_unchanged_count = 0;
    size_t col1_scaled_count    = 0;
    for(index_t row = 0; row < rows_to_check; ++row)
    {
        // Row-major indexing into the flat output buffers.
        const size_t col0_idx = static_cast<size_t>(row * HalfConfig::kNPerBlock + kUnscaledCol);
        const size_t col1_idx = static_cast<size_t>(row * HalfConfig::kNPerBlock + kScaledCol);
        const auto unscaled_col0 = type_convert<float>(results.first.mData[col0_idx]);
        const auto scaled_col0   = type_convert<float>(results.second.mData[col0_idx]);
        const auto unscaled_col1 = type_convert<float>(results.first.mData[col1_idx]);
        const auto scaled_col1   = type_convert<float>(results.second.mData[col1_idx]);
        // Count rows where column 0 is unchanged (scale = kIdentityScale)
        if(std::abs(scaled_col0 - unscaled_col0) < kScaleEpsilon)
        {
            col0_unchanged_count++;
        }
        // Count rows where column 1 is scaled by kTestScaleFactor
        const float expected_scaled = unscaled_col1 * kTestScaleFactor;
        if(std::abs(scaled_col1 - expected_scaled) < kScaleEpsilon)
        {
            col1_scaled_count++;
        }
    }
    // All rows must have correct scaling
    EXPECT_EQ(col0_unchanged_count, static_cast<size_t>(rows_to_check))
        << "RowCol: not all rows have unchanged col0";
    EXPECT_EQ(col1_scaled_count, static_cast<size_t>(rows_to_check))
        << "RowCol: not all rows have scaled col1";
}
TEST_F(CShuffleEpilogueScaleTest, HalfTestWithTensorScale)
{
    // Execute the epilogue once without scaling and once with Tensor scaling.
    auto [baseline, scaled] = run_scale_comparison_test<ScaleTestProblem,
                                                        HalfConfig::kMPerBlock,
                                                        HalfConfig::kNPerBlock,
                                                        ScaleType::Tensor>();
    // Sorting both outputs makes them comparable element-wise regardless of the
    // permutation the shuffle applies.
    const auto unscaled_vals = convert_and_sort_output(baseline);
    const auto scaled_vals   = convert_and_sort_output(scaled);
    // With Tensor scaling (m_scale=kTestScaleFactor, n_scale=kIdentityScale),
    // every element should be multiplied by kTestScaleFactor.
    EXPECT_EQ(unscaled_vals.size(), scaled_vals.size()) << "Tensor scale: output sizes differ";
    for(size_t idx = 0; idx < unscaled_vals.size(); ++idx)
    {
        const float expected = unscaled_vals[idx] * kTestScaleFactor;
        EXPECT_NEAR(scaled_vals[idx], expected, kScaleEpsilon)
            << "Tensor scale: sorted scaled[" << idx << "]=" << scaled_vals[idx] << " should be "
            << kTestScaleFactor << "x " << unscaled_vals[idx];
    }
}
// Global test environment to check for wave32 devices
class Wave32CheckEnvironment : public ::testing::Environment
{
public:
void SetUp() override
{
int warp_size = 0;
hipError_t err = hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, 0);
if(err == hipSuccess && warp_size == 32)
{
GTEST_SKIP() << "CShuffleEpilogue tests not supported on wave32 devices";
}
}
};
int main(int argc, char** argv)
{
::testing::InitGoogleTest(&argc, argv);
::testing::AddGlobalTestEnvironment(new Wave32CheckEnvironment);
return RUN_ALL_TESTS();
}

View File

@@ -4,16 +4,18 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include <algorithm>
#include <array>
#include <iostream>
#include <memory>
#include <numeric>
#include <random>
#include <utility>
#include <vector>
#include <hip/hip_runtime.h>
@@ -28,11 +30,14 @@ enum class ScaleType
// Simple test kernel to invoke the CShuffleEpilogue
template <typename Problem, index_t M, index_t N, ScaleType Scale>
__global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __restrict__ output_data,
float* m_scale,
float* n_scale)
__global__ void
test_cshuffle_epilogue_kernel(const typename Problem::AccDataType* __restrict__ input_data,
typename Problem::ODataType* __restrict__ output_data,
float* m_scale,
float* n_scale)
{
using Epilogue = CShuffleEpilogue<Problem>;
using Epilogue = CShuffleEpilogue<Problem>;
using AccDataType = typename Epilogue::AccDataType;
static_assert(Problem::kMPerBlock <= M && Problem::kNPerBlock <= N,
"Block size must fit in tensor dimensions");
@@ -40,15 +45,55 @@ __global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __res
// Allocate shared memory for epilogue
__shared__ char smem[Epilogue::GetSmemSize()];
// Create accumulator tile
constexpr auto lds_distribution_encode =
make_static_tile_distribution(Epilogue::MakeLdsDistributionEncode());
auto acc_tile =
make_static_distributed_tensor<typename Epilogue::AccDataType>(lds_distribution_encode);
// Create accumulator tile with GEMM accumulator distribution (matches BlockGemm)
using WG = ck_tile::WarpGemmDispatcher<typename Epilogue::ADataType,
typename Epilogue::BDataType,
typename Problem::AccDataType,
Problem::MPerXdl,
Problem::NPerXdl,
Problem::KPerXdl,
Problem::isCTransposed>;
// Fill acc_tile with a simple pattern
auto& acc_buffer = acc_tile.get_thread_buffer();
acc_buffer[0] = 2.0F;
constexpr index_t MIterPerWarp = Problem::kMPerBlock / (Problem::MWave * Problem::MPerXdl);
constexpr index_t NIterPerWarp = Problem::kNPerBlock / (Problem::NWave * Problem::NPerXdl);
constexpr auto c_block_outer_dstr_encoding = ck_tile::tile_distribution_encoding<
ck_tile::sequence<>,
ck_tile::tuple<ck_tile::sequence<MIterPerWarp, Problem::MWave>,
ck_tile::sequence<NIterPerWarp, Problem::NWave>>,
ck_tile::tuple<ck_tile::sequence<1, 2>>,
ck_tile::tuple<ck_tile::sequence<1, 1>>,
ck_tile::sequence<1, 2>,
ck_tile::sequence<0, 0>>{};
constexpr auto acc_distribution_encode = ck_tile::detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto acc_distribution = make_static_tile_distribution(acc_distribution_encode);
auto acc_tile = make_static_distributed_tensor<AccDataType>(acc_distribution);
// Create input tensor view for loading from global memory
// Note: cast away const since buffer_view infrastructure doesn't support const pointers,
// but the input_data is only read, never written
// Use runtime values for dimensions to avoid issues with constant buffer size types
constexpr index_t kMPerBlock = Problem::kMPerBlock;
constexpr index_t kNPerBlock = Problem::kNPerBlock;
auto input_tensor_view = make_naive_tensor_view<address_space_enum::global>(
const_cast<AccDataType*>(input_data),
make_tuple(kMPerBlock, kNPerBlock),
make_tuple(kNPerBlock, 1), // row-major strides
number<1>{},
number<1>{});
// Create tile window using the correct accumulator distribution
auto input_tile_window =
make_tile_window(input_tensor_view,
make_tuple(number<Problem::kMPerBlock>{}, number<Problem::kNPerBlock>{}),
{0, 0},
acc_distribution); // Use GEMM acc distribution, not LDS distribution
// Load input data from global memory into acc_tile
load_tile(acc_tile, input_tile_window);
// Create output tensor view
auto output_tensor_view =
@@ -123,83 +168,216 @@ using SimpleCShuffleEpilogueProblem =
false // isCTransposed
>;
// Launch kernel with RowCol scaling.
template <typename Problem, index_t M, index_t N>
void launch_kernel_with_rowcol_scale(const typename Problem::AccDataType* device_input,
                                     typename Problem::ODataType* device_output,
                                     dim3 gridSize,
                                     dim3 blockSize)
{
    // Per-row scales are all identity; per-column scales are identity except
    // column 1, which is doubled so the test can detect column-wise scaling.
    HostTensor<float> h_m_scale({M});
    HostTensor<float> h_n_scale({N});
    std::fill(h_m_scale.mData.begin(), h_m_scale.mData.end(), 1.0F);
    std::fill(h_n_scale.mData.begin(), h_n_scale.mData.end(), 1.0F);
    h_n_scale.mData[1] = 2.0F;
    // Stage both scale vectors in device memory.
    DeviceMem m_scale_buf(h_m_scale.get_element_space_size_in_bytes());
    DeviceMem n_scale_buf(h_n_scale.get_element_space_size_in_bytes());
    m_scale_buf.ToDevice(h_m_scale.data());
    n_scale_buf.ToDevice(h_n_scale.data());
    auto* m_scale_dev = static_cast<float*>(m_scale_buf.GetDeviceBuffer());
    auto* n_scale_dev = static_cast<float*>(n_scale_buf.GetDeviceBuffer());
    test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::RowCol>
        <<<gridSize, blockSize>>>(device_input, device_output, m_scale_dev, n_scale_dev);
    // Surface launch errors and wait for completion before the buffers go out of scope.
    HIP_CHECK_ERROR(hipGetLastError());
    HIP_CHECK_ERROR(hipDeviceSynchronize());
}
// Launch kernel with Tensor scaling
template <typename Problem, index_t M, index_t N>
void launch_kernel_with_tensor_scale(const typename Problem::AccDataType* device_input,
typename Problem::ODataType* device_output,
dim3 gridSize,
dim3 blockSize)
{
HostTensor<float> h_m_scale({1});
HostTensor<float> h_n_scale({1});
h_m_scale.mData[0] = 2.0F;
h_n_scale.mData[0] = 1.0F;
DeviceMem m_scale_buf(h_m_scale.get_element_space_size_in_bytes());
DeviceMem n_scale_buf(h_n_scale.get_element_space_size_in_bytes());
m_scale_buf.ToDevice(h_m_scale.data());
n_scale_buf.ToDevice(h_n_scale.data());
test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::Tensor>
<<<gridSize, blockSize>>>(device_input,
device_output,
static_cast<float*>(m_scale_buf.GetDeviceBuffer()),
static_cast<float*>(n_scale_buf.GetDeviceBuffer()));
HIP_CHECK_ERROR(hipGetLastError());
HIP_CHECK_ERROR(hipDeviceSynchronize());
}
// Launch kernel without scaling.
template <typename Problem, index_t M, index_t N>
void launch_kernel_without_scale(const typename Problem::AccDataType* device_input,
                                 typename Problem::ODataType* device_output,
                                 dim3 gridSize,
                                 dim3 blockSize)
{
    // Scale pointers are unused for ScaleType::None, so pass nullptr.
    test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::None>
        <<<gridSize, blockSize>>>(device_input, device_output, nullptr, nullptr);
    // Surface both launch errors and runtime faults before returning.
    HIP_CHECK_ERROR(hipGetLastError());
    HIP_CHECK_ERROR(hipDeviceSynchronize());
}
/// Produce N distinct fp16 bit patterns drawn from the normal range.
/// The first 30720 entries are the positive normals 0x0400..0x7BFF in order;
/// any remaining entries continue through the negative normals starting at 0x8400.
/// Compile-time error if N > 61440 (total normal fp16 encodings used here).
template <size_t N>
constexpr std::array<uint16_t, N> generate_fp16_bit_patterns()
{
    static_assert(N <= 61440, "N exceeds available unique normal fp16 values");
    constexpr uint16_t kPosStart         = 0x0400; // smallest positive normal encoding
    constexpr uint16_t kNegStart         = 0x8400; // first negative normal encoding
    constexpr size_t kMaxPositiveNormals = 30720;  // count of positive normal encodings
    std::array<uint16_t, N> patterns{};
    size_t idx = 0;
    // Walk the positive normals first...
    for(; idx < N && idx < kMaxPositiveNormals; ++idx)
    {
        patterns[idx] = static_cast<uint16_t>(kPosStart + idx);
    }
    // ...then spill over into the negative normals.
    for(; idx < N; ++idx)
    {
        patterns[idx] = static_cast<uint16_t>(kNegStart + (idx - kMaxPositiveNormals));
    }
    return patterns;
}
/// Convert fp16 bit patterns to float values.
/// Performs: uint16_t -> half_t (bit_cast) -> float
/// @param bits  raw fp16 encodings (as produced by generate_fp16_bit_patterns)
/// @return      the corresponding values widened to float
template <size_t N>
std::array<float, N> convert_fp16_bits(const std::array<uint16_t, N>& bits)
{
    std::array<float, N> result;
    for(size_t i = 0; i < N; ++i)
    {
        // Reinterpret the raw bits as half precision, then widen to float.
        half_t h  = bit_cast<half_t>(bits[i]);
        result[i] = type_convert<float>(h);
    }
    return result;
}
/// Generate unique fp16 values as a HostTensor for permutation testing.
/// Uses layered architecture: bit patterns -> type conversion -> HostTensor.
/// @tparam AccDataType  element type the kernel consumes
/// @tparam Rows, Cols   tensor extents; Rows*Cols must stay within the 61440
///                      unique normal fp16 encodings (enforced upstream).
template <typename AccDataType, index_t Rows, index_t Cols>
HostTensor<AccDataType> generate_unique_fp16_input()
{
    constexpr size_t N = static_cast<size_t>(Rows * Cols);
    // Compile-time table of distinct fp16 bit patterns, widened to float on the host.
    constexpr auto bits = generate_fp16_bit_patterns<N>();
    auto values = convert_fp16_bits(bits);
    // Lay the values out row-major so host_input(m, n) == values[m * Cols + n].
    HostTensor<AccDataType> host_input({Rows, Cols});
    for(index_t m = 0; m < Rows; ++m)
    {
        for(index_t n = 0; n < Cols; ++n)
        {
            host_input(m, n) = static_cast<AccDataType>(values[static_cast<size_t>(m * Cols + n)]);
        }
    }
    return host_input;
}
template <typename Problem, index_t M, index_t N>
auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None)
{
using ODataType = typename Problem::ODataType;
using AccDataType = typename Problem::AccDataType;
using ODataType = typename Problem::ODataType;
constexpr index_t kMPerBlock = Problem::kMPerBlock;
constexpr index_t kNPerBlock = Problem::kNPerBlock;
index_t kBlockSize = ck_tile::is_wave32() ? Problem::kBlockSize / 2 : Problem::kBlockSize;
const index_t kBlockSize = ck_tile::is_wave32() ? Problem::kBlockSize / 2 : Problem::kBlockSize;
std::cout << "Running CShuffleEpilogue test with M=" << M << ", N=" << N
<< ", MPerBlock=" << kMPerBlock << ", NPerBlock=" << kNPerBlock
<< ", BlockSize=" << kBlockSize << std::endl;
// Allocate host memory
const size_t output_size = M * N;
HostTensor<AccDataType> host_input =
generate_unique_fp16_input<AccDataType, kMPerBlock, kNPerBlock>();
std::vector<ODataType> host_output(output_size, static_cast<ODataType>(0));
// Allocate device input and copy from host
DeviceMem device_input_buf(host_input.get_element_space_size_in_bytes());
device_input_buf.ToDevice(host_input.data());
auto* device_input = static_cast<const AccDataType*>(device_input_buf.GetDeviceBuffer());
// Allocate device memory
ODataType* device_output;
// Allocate host output memory
HostTensor<ODataType> host_output({M, N});
host_output.SetZero();
HIP_CHECK_ERROR(hipMalloc(&device_output, output_size * sizeof(ODataType)));
// Allocate device output memory
DeviceMem device_output_buf(host_output.get_element_space_size_in_bytes());
device_output_buf.ToDevice(host_output.data());
ODataType* device_output = static_cast<ODataType*>(device_output_buf.GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(
device_output, host_output.data(), output_size * sizeof(ODataType), hipMemcpyHostToDevice));
// Launch kernel
// Launch kernel with appropriate scale configuration
dim3 gridSize(1, 1, 1);
dim3 blockSize(kBlockSize, 1, 1);
if(scale == ScaleType::RowCol)
switch(scale)
{
float* m_scale;
float* n_scale;
std::vector<float> h_m_scale(M, 1.0F);
std::vector<float> h_n_scale(N, 1.0F);
h_n_scale[1] = 2.0F; // multiply one col only with 2
HIP_CHECK_ERROR(hipMalloc(&m_scale, M * sizeof(float)));
HIP_CHECK_ERROR(hipMalloc(&n_scale, N * sizeof(float)));
HIP_CHECK_ERROR(
hipMemcpy(m_scale, h_m_scale.data(), M * sizeof(float), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(
hipMemcpy(n_scale, h_n_scale.data(), N * sizeof(float), hipMemcpyHostToDevice));
test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::RowCol>
<<<gridSize, blockSize>>>(device_output, m_scale, n_scale);
case ScaleType::RowCol:
launch_kernel_with_rowcol_scale<Problem, M, N>(
device_input, device_output, gridSize, blockSize);
break;
case ScaleType::Tensor:
launch_kernel_with_tensor_scale<Problem, M, N>(
device_input, device_output, gridSize, blockSize);
break;
case ScaleType::None:
launch_kernel_without_scale<Problem, M, N>(
device_input, device_output, gridSize, blockSize);
break;
}
else if(scale == ScaleType::Tensor)
{
float* m_scale;
float* n_scale;
std::vector<float> h_m_scale(1, 2.0F);
std::vector<float> h_n_scale(1, 1.0F);
HIP_CHECK_ERROR(hipMalloc(&m_scale, sizeof(float)));
HIP_CHECK_ERROR(hipMalloc(&n_scale, sizeof(float)));
HIP_CHECK_ERROR(hipMemcpy(m_scale, h_m_scale.data(), sizeof(float), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(n_scale, h_n_scale.data(), sizeof(float), hipMemcpyHostToDevice));
test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::Tensor>
<<<gridSize, blockSize>>>(device_output, m_scale, n_scale);
}
else
{
test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::None>
<<<gridSize, blockSize>>>(device_output, nullptr, nullptr);
}
// Check for kernel launch errors
HIP_CHECK_ERROR(hipGetLastError());
HIP_CHECK_ERROR(hipDeviceSynchronize());
// Copy results back
HIP_CHECK_ERROR(hipMemcpy(
host_output.data(), device_output, output_size * sizeof(ODataType), hipMemcpyDeviceToHost));
// Cleanup
HIP_CHECK_ERROR(hipFree(device_output));
device_output_buf.FromDevice(host_output.data());
return host_output;
}
// Flatten the output tensor into floats and sort ascending so verification can
// scan it linearly. float is used as the intermediate to preserve precision for
// floating-point comparison.
template <typename ODataType>
std::vector<float> convert_and_sort_output(const HostTensor<ODataType>& output)
{
    const size_t count = output.get_element_size();
    std::vector<float> sorted_vals(count);
    for(size_t idx = 0; idx < count; ++idx)
    {
        sorted_vals[idx] = type_convert<float>(output.mData[idx]);
    }
    std::sort(sorted_vals.begin(), sorted_vals.end());
    return sorted_vals;
}
// Run both unscaled and scaled tests for comparison.
// Returns pair of (unscaled_output, scaled_output) host tensors.
template <typename Problem, index_t M, index_t N, ScaleType ScaleMode>
auto run_scale_comparison_test()
{
    // Baseline run with no scaling, then an identical run with ScaleMode applied;
    // both runs use the same deterministic generated input.
    auto unscaled_output = run_cshuffle_epilogue_test<Problem, M, N>(ScaleType::None);
    auto scaled_output = run_cshuffle_epilogue_test<Problem, M, N>(ScaleMode);
    return std::make_pair(std::move(unscaled_output), std::move(scaled_output));
}
} // namespace ck_tile