mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Fix CShuffleEpilogue test to use correct GEMM accumulator distribution (#4518)

## Summary

The test was using the LDS distribution to create the accumulator tile, but CShuffleEpilogue expects the GEMM accumulator distribution that BlockGemm produces. This mismatch caused incorrect data permutation.

## Changes

- Use WarpGemmDispatcher to get the correct accumulator distribution encoding
- Load test input from host-initialized global memory for deterministic verification
- Shard tests by data type (FP16, FP8) with gfx950-specific FP8 tests
- Extract scale tests into a separate target for better organization
- Implement exact permutation verification (all unique values appear once)
- Reduce tile size from 256x256 to 128x128 to fit in the unique fp16 range
- Add parameterized test configurations for various warp layouts and MFMA types

## Test plan

- [x] Run new cshuffle epilogue tests

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
143 lines
6.1 KiB
C++
143 lines
6.1 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
/**
|
|
* CShuffleEpilogue Test Infrastructure
|
|
*
|
|
* File organization:
|
|
* - test_cshuffle_epilogue_common.hpp: TileConfig template, verification helpers,
|
|
* typed test suite definition (this file)
|
|
* - test_cshuffle_epilogue_util.hpp: Kernel templates, launch helpers, test runners
|
|
* - test_cshuffle_epilogue_fp16.cpp: FP16 tile configurations
|
|
* - test_cshuffle_epilogue_fp8.cpp: FP8 tile configurations (standard)
|
|
* - test_cshuffle_epilogue_fp8_gfx950.cpp: FP8 configurations for gfx950
|
|
* - test_cshuffle_epilogue_scale.cpp: RowCol and Tensor scaling tests
|
|
*/
|
|
|
|
#pragma once
|
|
|
|
#include "test_cshuffle_epilogue_util.hpp"
|
|
#include <algorithm>
|
|
#include <cmath>
|
|
#include <gtest/gtest.h>
|
|
#include <vector>
|
|
|
|
// TileConfig defines a test configuration for CShuffleEpilogue.
|
|
// - ODataType_: The output data type written to global memory
|
|
// - MfmaDataType_: The data type used for MFMA instruction selection (determines valid KPerXdl)
|
|
// Defaults to ODataType_ but can differ (e.g., FP8 MFMA tiles with FP16 output
|
|
// to avoid FP8 range limitations in test verification)
|
|
template <typename ODataType_,
|
|
ck_tile::index_t MPerBlock_,
|
|
ck_tile::index_t NPerBlock_,
|
|
ck_tile::index_t MWave_,
|
|
ck_tile::index_t NWave_,
|
|
ck_tile::index_t MPerXdl_,
|
|
ck_tile::index_t NPerXdl_,
|
|
ck_tile::index_t KPerXdl_,
|
|
typename MfmaDataType_ = ODataType_>
|
|
struct TileConfig
|
|
{
|
|
using DataType = ODataType_;
|
|
using MfmaDataType = MfmaDataType_;
|
|
static constexpr ck_tile::index_t kMPerBlock = MPerBlock_;
|
|
static constexpr ck_tile::index_t kNPerBlock = NPerBlock_;
|
|
static constexpr ck_tile::index_t MWave = MWave_;
|
|
static constexpr ck_tile::index_t NWave = NWave_;
|
|
static constexpr ck_tile::index_t MPerXdl = MPerXdl_;
|
|
static constexpr ck_tile::index_t NPerXdl = NPerXdl_;
|
|
static constexpr ck_tile::index_t KPerXdl = KPerXdl_;
|
|
};
|
|
|
|
// Helper to construct SimpleCShuffleEpilogueProblem from TileConfig
|
|
// Uses MfmaDataType for MFMA input types (A/B) and DataType for output
|
|
template <typename Config, typename AccDataType = float>
|
|
using MakeProblem = ck_tile::SimpleCShuffleEpilogueProblem<typename Config::MfmaDataType,
|
|
typename Config::MfmaDataType,
|
|
AccDataType,
|
|
typename Config::DataType,
|
|
Config::kMPerBlock,
|
|
Config::kNPerBlock,
|
|
Config::MWave,
|
|
Config::NWave,
|
|
Config::MPerXdl,
|
|
Config::NPerXdl,
|
|
Config::KPerXdl>;
|
|
|
|
// Verification helper: check that output contains valid data from the epilogue shuffle.
|
|
// The C-shuffle epilogue loads thread-local values and writes them to output through LDS.
|
|
// We verify: correct output size, no NaN values, no unwritten zeros, and at least
|
|
// kBlockSize unique values (one per thread).
|
|
template <typename DataType,
|
|
ck_tile::index_t kMPerBlock,
|
|
ck_tile::index_t kNPerBlock,
|
|
ck_tile::index_t kBlockSize>
|
|
void verify_permutation_output(const std::vector<float>& sorted_vals)
|
|
{
|
|
constexpr size_t expected_size = static_cast<size_t>(kMPerBlock * kNPerBlock);
|
|
|
|
ASSERT_EQ(sorted_vals.size(), expected_size) << "Output size mismatch";
|
|
|
|
// Verify no NaN values
|
|
for(size_t i = 0; i < sorted_vals.size(); ++i)
|
|
{
|
|
ASSERT_FALSE(std::isnan(sorted_vals[i])) << "NaN at index " << i;
|
|
}
|
|
|
|
// Count unique values using bit-exact comparison (sorted fp32 values from fp16 should be
|
|
// distinct)
|
|
size_t num_unique = 1;
|
|
for(size_t i = 1; i < sorted_vals.size(); ++i)
|
|
{
|
|
if(ck_tile::bit_cast<uint32_t>(sorted_vals[i]) !=
|
|
ck_tile::bit_cast<uint32_t>(sorted_vals[i - 1]))
|
|
{
|
|
++num_unique;
|
|
}
|
|
}
|
|
|
|
// Verify exact permutation: all input values should appear exactly once in output
|
|
EXPECT_EQ(num_unique, expected_size) << "Expected exact permutation with " << expected_size
|
|
<< " unique values, got " << num_unique;
|
|
}
|
|
|
|
// Type-parameterized test fixture
|
|
template <typename Config>
|
|
class CShuffleEpilogueTypedTest : public ::testing::Test
|
|
{
|
|
};
|
|
|
|
TYPED_TEST_SUITE_P(CShuffleEpilogueTypedTest);
|
|
|
|
// BasicTest: runs the epilogue without scaling for the configuration given by TypeParam and
// checks the sorted output with verify_permutation_output.
TYPED_TEST_P(CShuffleEpilogueTypedTest, BasicTest)
{
    // Epilogue problem description derived from the configuration under test.
    using Problem = MakeProblem<TypeParam>;
    using OutType = typename TypeParam::DataType;

    constexpr ck_tile::index_t kRows    = TypeParam::kMPerBlock;
    constexpr ck_tile::index_t kCols    = TypeParam::kNPerBlock;
    constexpr ck_tile::index_t kThreads = Problem::kBlockSize;

    // Execute the epilogue with scaling disabled and collect its output.
    auto epilogue_output =
        ck_tile::run_cshuffle_epilogue_test<Problem, kRows, kCols>(ck_tile::ScaleType::None);

    // The verification helper expects a sorted float sequence, so convert and sort first.
    auto sorted_output = ck_tile::convert_and_sort_output(epilogue_output);
    verify_permutation_output<OutType, kRows, kCols, kThreads>(sorted_output);
}

// Register the tests that belong to the suite; required before any instantiation.
REGISTER_TYPED_TEST_SUITE_P(CShuffleEpilogueTypedTest, BasicTest);

// Translation units may include this header without instantiating the suite (e.g. the
// scale tests); tell gtest that this is intentional.
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CShuffleEpilogueTypedTest);
|
|
|
|
// CK_INSTANTIATE_TYPED_TEST_SUITE(InstPrefix, SuiteName, TypeList)
//
// Forwards its three arguments unchanged to INSTANTIATE_TYPED_TEST_SUITE_P, wrapped in a
// clang diagnostic push/pop pair that silences -Wused-but-marked-unused for the expanded
// code only.
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define CK_INSTANTIATE_TYPED_TEST_SUITE(InstPrefix, SuiteName, TypeList) \
    _Pragma("clang diagnostic push")                                     \
    _Pragma("clang diagnostic ignored \"-Wused-but-marked-unused\"")     \
    INSTANTIATE_TYPED_TEST_SUITE_P(InstPrefix, SuiteName, TypeList);     \
    _Pragma("clang diagnostic pop")
|