Files
composable_kernel/test/ck_tile/epilogue/test_cshuffle_epilogue_fp8.cpp
Max Podkorytov 0438ab1b79 [rocm-libraries] ROCm/rocm-libraries#4518 (commit dd161dc)
[CK_TILE] Fix CShuffleEpilogue test to use correct GEMM accumulator distribution (#4518)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

The test was using LDS distribution to create the accumulator tile, but
CShuffleEpilogue expects the GEMM accumulator distribution that
BlockGemm produces. This mismatch caused incorrect data permutation.

## Changes

- Use WarpGemmDispatcher to get correct accumulator distribution
encoding
- Load test input from host-initialized global memory for deterministic
verification
- Shard tests by data type (FP16, FP8) with gfx950-specific FP8 tests
- Extract scale tests into separate target for better organization
- Implement exact permutation verification (all unique values appear
once)
- Reduce tile size from 256x256 to 128x128 so that every tile element can hold a unique fp16 value
- Add parameterized test configurations for various warp layouts and
MFMA types

## Test plan

- [x] Run new cshuffle epilogue tests

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-03-02 08:55:05 +00:00

49 lines
2.1 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_cshuffle_epilogue_common.hpp"
using namespace ck_tile;
// FP8 MFMA tile configurations with half_t output
// Using half_t output avoids FP8 range limitations while testing FP8-specific tile sizes
//
// Each alias below names one TileConfig instantiation; all of them are collected in FP8TestTypes.
// Every instantiation passes half_t first, 128 and 128 next, and fp8_t last; only the five
// integers in between differ.
// NOTE(review): TileConfig is not defined in this file (presumably it comes from
// test_cshuffle_epilogue_common.hpp). Judging by the alias names, the arguments look like
// <output type, M, N, M-warps, N-warps, MFMA M, MFMA N, MFMA K, input type> - confirm against
// the TileConfig declaration.
// NOTE(review): the trailing "x1" in the warp-layout part of each name has no matching template
// argument here; presumably a third warp dimension is fixed to 1 elsewhere - confirm.
//
// Same 16, 16, 16 trailing integers; the 4th/5th arguments vary (2,2 / 1,4 / 4,1).
using FP8Config_128x128_2x2x1_16x16x16 = TileConfig<half_t, 128, 128, 2, 2, 16, 16, 16, fp8_t>;
using FP8Config_128x128_1x4x1_16x16x16 = TileConfig<half_t, 128, 128, 1, 4, 16, 16, 16, fp8_t>;
using FP8Config_128x128_4x1x1_16x16x16 = TileConfig<half_t, 128, 128, 4, 1, 16, 16, 16, fp8_t>;
// 4th/5th arguments fixed at 2,2; the last three integers vary.
using FP8Config_128x128_2x2x1_32x32x16 = TileConfig<half_t, 128, 128, 2, 2, 32, 32, 16, fp8_t>;
using FP8Config_128x128_2x2x1_16x16x32 = TileConfig<half_t, 128, 128, 2, 2, 16, 16, 32, fp8_t>;
using FP8Config_128x128_2x2x1_32x32x32 = TileConfig<half_t, 128, 128, 2, 2, 32, 32, 32, fp8_t>;
using FP8Config_128x128_2x2x1_16x16x64 = TileConfig<half_t, 128, 128, 2, 2, 16, 16, 64, fp8_t>;
// googletest type list holding all seven FP8Config_* aliases, in declaration order.
using FP8TestTypes = ::testing::Types<FP8Config_128x128_2x2x1_16x16x16,
FP8Config_128x128_1x4x1_16x16x16,
FP8Config_128x128_4x1x1_16x16x16,
FP8Config_128x128_2x2x1_32x32x16,
FP8Config_128x128_2x2x1_16x16x32,
FP8Config_128x128_2x2x1_32x32x32,
FP8Config_128x128_2x2x1_16x16x64>;
// Passes the "FP8" label, the CShuffleEpilogueTypedTest suite and the type list to the project
// macro.
// NOTE(review): CK_INSTANTIATE_TYPED_TEST_SUITE and CShuffleEpilogueTypedTest are defined outside
// this file; presumably the macro instantiates the typed suite once per listed type - confirm.
CK_INSTANTIATE_TYPED_TEST_SUITE(FP8, CShuffleEpilogueTypedTest, FP8TestTypes)
// Process-wide googletest environment that opts out of the run on wave32 hardware.
//
// SetUp() reads the warp-size attribute of device 0 and issues GTEST_SKIP() when it is 32.
// A failed attribute query is treated as "not wave32": nothing is skipped and no error is raised.
class Wave32CheckEnvironment : public ::testing::Environment
{
    public:
    void SetUp() override
    {
        int lanes_per_wave = 0;
        const hipError_t query_status =
            hipDeviceGetAttribute(&lanes_per_wave, hipDeviceAttributeWarpSize, 0);

        // Only a successful query that reports exactly 32 lanes triggers the skip.
        if(query_status != hipSuccess || lanes_per_wave != 32)
        {
            return;
        }

        GTEST_SKIP() << "CShuffleEpilogue tests not supported on wave32 devices";
    }
};
// Test entry point: initialises googletest from the command line, registers the wave32 guard
// environment, then runs every test and returns googletest's result as the exit code.
int main(int argc, char** argv)
{
    ::testing::InitGoogleTest(&argc, argv);

    // Raw `new` is the registration idiom here: per googletest's documentation the framework
    // takes ownership of environments passed to AddGlobalTestEnvironment.
    auto* wave32_guard = new Wave32CheckEnvironment;
    ::testing::AddGlobalTestEnvironment(wave32_guard);

    const int exit_status = RUN_ALL_TESTS();
    return exit_status;
}