[rocm-libraries] ROCm/rocm-libraries#6212 (commit ccee58d)

=?UTF-8?q?[CK=20TILE]=20Unification=20Work=20=E2=80=93=20?=
 =?UTF-8?q?More=20accurate=20tests=20for=20MmaPipelines=20(#6212)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

This PR solves several issues:

#### More accurate tests for MmaPipelines

The current tests for the MmaPipelines (test_amdgcn_sparse_mma,
test_amdgcn_wavewise_mma) use explicit input fragment vectors filled
with 1s, and only check the output of a single lane. We should have
tests that actually use the MmaPipelines with non-trivial input matrices
and verify the complete output.
Some other aspects of the current MmaPipelines tests that I noticed and
deserve some attention:

1. There is sometimes iteration over K outside of the pipeline, which is
then included in WaveTileK or FragK, which is not correct. We should
remove it, move K iteration inside of the pipeline, or be more clear
about this outer-K loop size and how it propagates downwards.
2. There is very tight coupling between the kernel, gtest code, and
test_pipeline helper, requiring a lot of information and functions to be
passed back and forth.
3. The test_pipeline helper is doing a bunch of register-related logic
on the host (related to point 1)
4. Without this register logic the only thing it does is check the
device, call the kernel, and check the output, but with a lot of
boilerplate.

#### Test helper for detecting target arch at HOST runtime

There is a really apparent issue we faced while writing tests:

Scenario:
1. Compile a test that supports both gfx950 and gfx1201 for gfx950
2. Run the test on a server that only has gfx1201 GPU

Actual:
Segmentation fault

Expected:
The test can correctly detect from HOST runtime that the DEVICE
target_id was different and skips the test.

Notes:

The only way of detecting the COMPILER_TARGET_ID in the existing "arch"
framework is launching a kernel and calling `get_compiler_target()` (so,
from a DEVICE code). This will create a segmentation fault if the
current arch differs from the target arch. To cope with this issue, we
propose to export the compiler target(s) (note they can be many) through
`projects/composablekernel/test/ck_tile/core/arch/CMakeLists.txt` and
define a test helper to deal with such cases.

#### Add composition support to Transforms

We have a small number of Transforms which act on MmaOp input and output
data, before and after the MmaOp call respectively. These are currently
implemented to work on an MmaTile level, but in theory they are also
supposed to work at a WaveTile level, i.e. after composition of multiple
MmaTiles to create larger effective MNK dimensions. Currently the
composed MmaTiles look like 2D C-style arrays of the individual MmaTile
level register vectors (see WaveWiseMmaPipeline). The transforms should
be able to take these and perform the proper transforms to the whole
WaveTile at once. This might allow for better performing
transformations.

Note: This PR handles the SparseTransform case and if we don't end up
doing scale as a transformation, there isn't really much left to do. If
we end up having only the sparse transform as a non-trivial transform,
then we could also consider removing the Transform framework.
This commit is contained in:
chris-tsiaousis-hpc
2026-06-03 14:35:18 +00:00
committed by assistant-librarian[bot]
parent 88f8d24c34
commit db05d61136
20 changed files with 1646 additions and 589 deletions

View File

@@ -7,19 +7,109 @@ if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
if(GPU_TARGETS MATCHES "gfx9|gfx120")
add_gtest_executable(test_amdgcn_sparse_mma pipeline/test_amdgcn_sparse_mma.cpp)
# ---------------------------------------------------------------------------
# Map GPU target strings to hex amdgcn_target_id values (arch.hpp).
# Builds a -DCK_CMAKE_GPU_TARGET_IDS=0xHHHH,... definition that host-side
# test code can consume without launching a device kernel.
# ---------------------------------------------------------------------------
function(_ck_gpu_target_string_to_id TARGET_STR OUT_VAR)
string(TOLOWER "${TARGET_STR}" _tgt)
string(REGEX REPLACE ":.*" "" _tgt "${_tgt}")
# GFX9
if(_tgt STREQUAL "gfx908")
set(${OUT_VAR} "0x0908" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx90a")
set(${OUT_VAR} "0x090A" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx942")
set(${OUT_VAR} "0x0942" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx950")
set(${OUT_VAR} "0x0950" PARENT_SCOPE)
# GFX10.3
elseif(_tgt STREQUAL "gfx1030")
set(${OUT_VAR} "0x1030" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1031")
set(${OUT_VAR} "0x1031" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1032")
set(${OUT_VAR} "0x1032" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1033")
set(${OUT_VAR} "0x1033" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1034")
set(${OUT_VAR} "0x1034" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1035")
set(${OUT_VAR} "0x1035" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1036")
set(${OUT_VAR} "0x1036" PARENT_SCOPE)
elseif(_tgt MATCHES "^gfx10-3-generic$")
set(${OUT_VAR} "0x103F" PARENT_SCOPE)
# GFX11
elseif(_tgt STREQUAL "gfx1100")
set(${OUT_VAR} "0x1100" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1101")
set(${OUT_VAR} "0x1101" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1102")
set(${OUT_VAR} "0x1102" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1103")
set(${OUT_VAR} "0x1103" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1150")
set(${OUT_VAR} "0x1150" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1151")
set(${OUT_VAR} "0x1151" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1152")
set(${OUT_VAR} "0x1152" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1153")
set(${OUT_VAR} "0x1153" PARENT_SCOPE)
elseif(_tgt MATCHES "^gfx11-generic$")
set(${OUT_VAR} "0x11FF" PARENT_SCOPE)
# GFX12
elseif(_tgt STREQUAL "gfx1200")
set(${OUT_VAR} "0x1200" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1201")
set(${OUT_VAR} "0x1201" PARENT_SCOPE)
elseif(_tgt MATCHES "^gfx12-generic$")
set(${OUT_VAR} "0x12FF" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1250")
set(${OUT_VAR} "0x1250" PARENT_SCOPE)
else()
message(WARNING "_ck_gpu_target_string_to_id: unknown GPU target '${TARGET_STR}', skipping")
set(${OUT_VAR} "" PARENT_SCOPE)
endif()
endfunction()
function(_ck_add_gpu_target_ids_define TARGET_NAME)
get_property(_archs TARGET ${TARGET_NAME} PROPERTY HIP_ARCHITECTURES)
string(REPLACE "," ";" _archs "${_archs}")
set(_hex_ids)
foreach(_tgt IN LISTS _archs)
_ck_gpu_target_string_to_id("${_tgt}" _hex)
if(_hex AND NOT _hex STREQUAL "0x0000")
list(APPEND _hex_ids "${_hex}")
endif()
endforeach()
list(JOIN _hex_ids "," _hex_str)
if(_hex_str)
target_compile_definitions(${TARGET_NAME} PRIVATE "CK_CMAKE_GPU_TARGET_IDS=${_hex_str}")
endif()
endfunction()
# Convenience: add_gtest_executable + inject CK_CMAKE_GPU_TARGET_IDS
macro(_add_mma_gtest TEST_NAME)
add_gtest_executable(${TEST_NAME} ${ARGN})
_ck_add_gpu_target_ids_define(${TEST_NAME})
endmacro()
# ---------------------------------------------------------------------------
if(GPU_TARGETS MATCHES "gfx9|gfx120")
_add_mma_gtest(test_amdgcn_sparse_mma pipeline/test_amdgcn_sparse_mma.cpp)
target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
if(GPU_TARGETS MATCHES "gfx950")
add_gtest_executable(test_amdgcn_scale_mma pipeline/test_amdgcn_scale_mma.cpp)
_add_mma_gtest(test_amdgcn_scale_mma pipeline/test_amdgcn_scale_mma.cpp)
target_compile_options(test_amdgcn_scale_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp)
_add_mma_gtest(test_amdgcn_mma test_amdgcn_mma.cpp)
target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_amdgcn_wavewise_mma pipeline/test_amdgcn_wavewise_mma.cpp)
_add_mma_gtest(test_amdgcn_wavewise_mma pipeline/test_amdgcn_wavewise_mma.cpp)
target_compile_options(test_amdgcn_wavewise_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile_gemm tests for current target")
@@ -37,45 +127,45 @@ macro(set_mma_test_arch_define target_name)
endmacro()
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_amdgcn_mma_layout_gfx9 test_amdgcn_mma_layout_gfx9.cpp)
_add_mma_gtest(test_amdgcn_mma_layout_gfx9 test_amdgcn_mma_layout_gfx9.cpp)
target_compile_options(test_amdgcn_mma_layout_gfx9 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
set_mma_test_arch_define(test_amdgcn_mma_layout_gfx9)
endif()
if(GPU_TARGETS MATCHES "gfx908|gfx90a")
add_gtest_executable(test_amdgcn_mma_layout_gfx908_and_gfx90a test_amdgcn_mma_layout_gfx908_and_gfx90a.cpp)
_add_mma_gtest(test_amdgcn_mma_layout_gfx908_and_gfx90a test_amdgcn_mma_layout_gfx908_and_gfx90a.cpp)
target_compile_options(test_amdgcn_mma_layout_gfx908_and_gfx90a PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
set_mma_test_arch_define(test_amdgcn_mma_layout_gfx908_and_gfx90a)
endif()
if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
add_gtest_executable(test_amdgcn_mma_layout_gfx90a_and_higher test_amdgcn_mma_layout_gfx90a_and_higher.cpp)
_add_mma_gtest(test_amdgcn_mma_layout_gfx90a_and_higher test_amdgcn_mma_layout_gfx90a_and_higher.cpp)
target_compile_options(test_amdgcn_mma_layout_gfx90a_and_higher PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
set_mma_test_arch_define(test_amdgcn_mma_layout_gfx90a_and_higher)
endif()
if(GPU_TARGETS MATCHES "gfx942|gfx950")
add_gtest_executable(test_amdgcn_mma_layout_gfx942_and_higher test_amdgcn_mma_layout_gfx942_and_higher.cpp)
_add_mma_gtest(test_amdgcn_mma_layout_gfx942_and_higher test_amdgcn_mma_layout_gfx942_and_higher.cpp)
target_compile_options(test_amdgcn_mma_layout_gfx942_and_higher PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS} -Wno-header-hygiene)
set_mma_test_arch_define(test_amdgcn_mma_layout_gfx942_and_higher)
endif()
if(GPU_TARGETS MATCHES "gfx950")
add_gtest_executable(test_amdgcn_mma_layout_gfx950 test_amdgcn_mma_layout_gfx950.cpp)
_add_mma_gtest(test_amdgcn_mma_layout_gfx950 test_amdgcn_mma_layout_gfx950.cpp)
target_compile_options(test_amdgcn_mma_layout_gfx950 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
set_mma_test_arch_define(test_amdgcn_mma_layout_gfx950)
endif()
if(GPU_TARGETS MATCHES "gfx11")
add_gtest_executable(test_amdgcn_mma_layout_gfx11 test_amdgcn_mma_layout_gfx11.cpp)
_add_mma_gtest(test_amdgcn_mma_layout_gfx11 test_amdgcn_mma_layout_gfx11.cpp)
target_compile_options(test_amdgcn_mma_layout_gfx11 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
if(GPU_TARGETS MATCHES "gfx120")
add_gtest_executable(test_amdgcn_mma_layout_gfx12 test_amdgcn_mma_layout_gfx12.cpp)
_add_mma_gtest(test_amdgcn_mma_layout_gfx12 test_amdgcn_mma_layout_gfx12.cpp)
target_compile_options(test_amdgcn_mma_layout_gfx12 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
add_gtest_executable(test_amdgcn_mma_pipeline pipeline/test_amdgcn_mma_pipeline.cpp)
_add_mma_gtest(test_amdgcn_mma_pipeline pipeline/test_amdgcn_mma_pipeline.cpp)
target_compile_options(test_amdgcn_mma_pipeline PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -0,0 +1,87 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include <unordered_set>
namespace ck_tile::core::arch::testing {
static CK_TILE_HOST auto getCMakeGpuTargetIds()
{
using ck_tile::core::arch::amdgcn_target_id;
#ifdef CK_CMAKE_GPU_TARGET_IDS
constexpr uint32_t ids[] = {CK_CMAKE_GPU_TARGET_IDS};
std::unordered_set<amdgcn_target_id> result;
for(auto id : ids)
result.insert(static_cast<amdgcn_target_id>(id));
return result;
#else
return std::unordered_set<amdgcn_target_id>{};
#endif
}
template <typename Func>
static CK_TILE_HOST bool dispatchCompilerTarget(ck_tile::core::arch::amdgcn_target_id id,
Func&& func)
{
using namespace ck_tile::core::arch;
// clang-format off
switch(id)
{
case amdgcn_target_id::GFX908: func(make_amdgcn_gfx9_target<amdgcn_target_id::GFX908>()); return true;
case amdgcn_target_id::GFX90A: func(make_amdgcn_gfx9_target<amdgcn_target_id::GFX90A>()); return true;
case amdgcn_target_id::GFX942: func(make_amdgcn_gfx9_target<amdgcn_target_id::GFX942>()); return true;
case amdgcn_target_id::GFX950: func(make_amdgcn_gfx9_target<amdgcn_target_id::GFX950>()); return true;
case amdgcn_target_id::GFX1030: func(make_amdgcn_gfx10_3_target<amdgcn_target_id::GFX1030>()); return true;
case amdgcn_target_id::GFX1031: func(make_amdgcn_gfx10_3_target<amdgcn_target_id::GFX1031>()); return true;
case amdgcn_target_id::GFX1032: func(make_amdgcn_gfx10_3_target<amdgcn_target_id::GFX1032>()); return true;
case amdgcn_target_id::GFX1033: func(make_amdgcn_gfx10_3_target<amdgcn_target_id::GFX1033>()); return true;
case amdgcn_target_id::GFX1034: func(make_amdgcn_gfx10_3_target<amdgcn_target_id::GFX1034>()); return true;
case amdgcn_target_id::GFX1035: func(make_amdgcn_gfx10_3_target<amdgcn_target_id::GFX1035>()); return true;
case amdgcn_target_id::GFX1036: func(make_amdgcn_gfx10_3_target<amdgcn_target_id::GFX1036>()); return true;
case amdgcn_target_id::GFX103_GENERIC: func(make_amdgcn_gfx10_3_target<amdgcn_target_id::GFX103_GENERIC>()); return true;
case amdgcn_target_id::GFX1100: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1100>()); return true;
case amdgcn_target_id::GFX1101: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1101>()); return true;
case amdgcn_target_id::GFX1102: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1102>()); return true;
case amdgcn_target_id::GFX1103: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1103>()); return true;
case amdgcn_target_id::GFX1150: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1150>()); return true;
case amdgcn_target_id::GFX1151: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1151>()); return true;
case amdgcn_target_id::GFX1152: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1152>()); return true;
case amdgcn_target_id::GFX1153: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1153>()); return true;
case amdgcn_target_id::GFX11_GENERIC: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX11_GENERIC>()); return true;
case amdgcn_target_id::GFX1200: func(make_amdgcn_gfx12_target<amdgcn_target_id::GFX1200>()); return true;
case amdgcn_target_id::GFX1201: func(make_amdgcn_gfx12_target<amdgcn_target_id::GFX1201>()); return true;
case amdgcn_target_id::GFX12_GENERIC: func(make_amdgcn_gfx12_target<amdgcn_target_id::GFX12_GENERIC>()); return true;
case amdgcn_target_id::GFX1250: func(make_amdgcn_gfx12_target<amdgcn_target_id::GFX1250>()); return true;
case amdgcn_target_id::HOST: return false;
}
// clang-format on
__builtin_unreachable();
}
static CK_TILE_HOST constexpr int32_t getCMakeWaveSize()
{
using ck_tile::core::arch::amdgcn_target_id;
#ifdef CK_CMAKE_GPU_TARGET_IDS
constexpr uint32_t ids[] = {CK_CMAKE_GPU_TARGET_IDS};
constexpr index_t targets_size = sizeof(ids) / sizeof(ids[0]);
static_assert(targets_size > 0);
constexpr auto first_target_id = static_cast<amdgcn_target_id>(ids[0]);
if constexpr(first_target_id >= amdgcn_target_id::GFX908 &&
first_target_id <= amdgcn_target_id::GFX950)
{
return 64;
}
else
{
return 32;
}
#else
static_assert(false, "Configure CK_CMAKE_GPU_TARGET_IDS before calling this function.");
return 0;
#endif
}
} // namespace ck_tile::core::arch::testing

View File

@@ -1,34 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdint>
#include <cstdio>
#include "ck_tile/core/arch/arch.hpp"
#include <hip/hip_runtime.h>
#include "ck_tile/host/hip_check_error.hpp"
namespace {
__global__ void getWaveSizeForSelectedOp(uint32_t* waveSize)
{
using CompilerTarget = decltype(ck_tile::core::arch::get_compiler_target());
if(waveSize)
*waveSize = static_cast<uint32_t>(CompilerTarget::WAVE_SIZE_ID);
}
static __host__ uint32_t getDeviceWaveSize()
{
uint32_t* d_wave_size;
HIP_CHECK_ERROR(hipMalloc(&d_wave_size, sizeof(uint32_t)));
getWaveSizeForSelectedOp<<<1, 64>>>(d_wave_size);
HIP_CHECK_ERROR(hipDeviceSynchronize());
uint32_t wave_size;
HIP_CHECK_ERROR(hipMemcpy(&wave_size, d_wave_size, sizeof(uint32_t), hipMemcpyDeviceToHost));
return wave_size;
}
} // namespace

View File

@@ -10,223 +10,421 @@
#include <gtest/gtest.h>
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include <hip/hip_runtime.h>
#include "../get_wave_size_helper.hpp"
#include "../get_cmake_targets_helper.hpp"
template <typename AType_ = ck_tile::fp16_t,
typename BType_ = ck_tile::fp16_t,
typename CType_ = ck_tile::fp32_t,
uint32_t WaveTileM_ = 16,
uint32_t WaveTileN_ = 16,
uint32_t WaveTileK_ = 32,
typename ScaleAType_ = int,
typename ScaleBType_ = int>
struct MmaPipelineTest
namespace mma_pipeline_test {
using namespace ck_tile;
using namespace ck_tile::core::arch;
using namespace ck_tile::core::arch::mma;
using namespace ck_tile::core::arch::testing;
inline bool hipTargetMatchesCmakeTargets(amdgcn_target_id arch)
{
using AType = AType_;
using BType = BType_;
using CType = CType_;
using ScaleAType = ScaleAType_;
using ScaleBType = ScaleBType_;
static constexpr auto WaveTileM = WaveTileM_;
static constexpr auto WaveTileN = WaveTileN_;
static constexpr auto WaveTileK = WaveTileK_;
void test_pipeline(std::function<bool(ck_tile::core::arch::amdgcn_target_id)> shouldSkip,
std::function<void(uint32_t, void*, void*, void*, void*)> kernel,
std::function<CType(uint32_t)> getExpected,
std::function<AType(size_t)> aInitializer = nullptr)
const auto cmake_targets = getCMakeGpuTargetIds();
if(cmake_targets.count(arch) == 0)
{
using namespace ck_tile;
using namespace ck_tile::core::arch;
int devCount;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceCount(&devCount));
hipDeviceProp_t devProp;
HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev));
auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName);
bool hasDevice = static_cast<bool>(devCount > 0);
int deviceWarpSize = devProp.warpSize;
if(!hasDevice || shouldSkip(currentArchId))
// gfx12-generic and gfx11-generic make no difference with the specialized archs.
// Some CI pipelines make use of that and configure the project with the generic
// flags besides compiling for (f.e.) gfx1201.
if(arch >= amdgcn_target_id::GFX1200 && arch <= amdgcn_target_id::GFX12_GENERIC)
{
GTEST_SKIP() << "No HIP device found. Skipping test.";
return (cmake_targets.count(amdgcn_target_id::GFX12_GENERIC) > 0);
}
// WaveTile size, also the expected fragment size (MmaTile) from the selector.
// Note: Actual FragK might be slightly different due to hardware implementation, but the
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
// correct.
static constexpr uint32_t FragM = WaveTileM;
static constexpr uint32_t FragN = WaveTileN;
static constexpr uint32_t FragK = WaveTileK;
// The number of elements per thread
uint32_t AElements = FragM * FragK / deviceWarpSize;
uint32_t BElements = FragN * FragK / deviceWarpSize;
uint32_t CElements = FragM * FragN / deviceWarpSize;
uint32_t ASize = AElements * sizeof(AType);
uint32_t BSize = BElements * sizeof(BType);
uint32_t CSize = CElements * sizeof(CType);
// Initialize A (use custom initializer or default all 1's), B to all 1's, C to all 0's
std::vector<AType> h_a(AElements);
if(aInitializer)
else if(arch >= amdgcn_target_id::GFX1100 && arch <= amdgcn_target_id::GFX11_GENERIC)
{
for(size_t i = 0; i < AElements; ++i)
h_a[i] = aInitializer(i);
return (cmake_targets.count(amdgcn_target_id::GFX11_GENERIC) > 0);
}
else
}
return true;
}
template <typename CType, typename AType, typename BType>
void reference_matmul(std::vector<CType>& C,
const std::vector<AType>& A,
const std::vector<BType>& B,
uint32_t M,
uint32_t N,
uint32_t K)
{
for(uint32_t m = 0; m < M; ++m)
{
for(uint32_t n = 0; n < N; ++n)
{
std::fill(h_a.begin(), h_a.end(), type_convert<AType>(1));
float acc = 0.0f;
for(uint32_t k = 0; k < K; ++k)
{
acc += type_convert<float>(A[m * K + k]) * type_convert<float>(B[k * N + n]);
}
C[m * N + n] = static_cast<CType>(acc);
}
std::vector<BType> h_b(BElements, type_convert<BType>(1));
std::vector<CType> h_c(CElements, type_convert<CType>(0));
std::vector<CType> h_out(CElements, type_convert<CType>(0));
}
}
AType* d_a;
BType* d_b;
CType* d_c;
CType* d_out;
template <typename T>
T deterministic_value(uint32_t row, uint32_t col, uint32_t minor_dim)
{
float v = static_cast<float>((row * minor_dim + col) % 7 + 1) * 0.25f;
return type_convert<T>(v);
}
HIP_CHECK_ERROR(hipMalloc(&d_a, ASize));
HIP_CHECK_ERROR(hipMalloc(&d_b, BSize));
HIP_CHECK_ERROR(hipMalloc(&d_c, CSize));
HIP_CHECK_ERROR(hipMalloc(&d_out, CSize));
// Copy inputs to device
HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
const auto wave_size = getDeviceWaveSize();
kernel(wave_size, d_a, d_b, d_c, d_out);
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
// Verify output against expected value for all elements
for(size_t i = 0; i < CElements; ++i)
// Apply 2:4 sparsity pattern to A matrix in-place (for sparse pipeline tests).
// Every group of 4 consecutive K elements keeps slots 0 and 2, zeros slots 1 and 3.
template <typename T>
void apply_sparse_pattern(std::vector<T>& A, uint32_t M, uint32_t K)
{
for(uint32_t m = 0; m < M; ++m)
{
for(uint32_t k = 0; k < K; k += 4)
{
EXPECT_NEAR(h_out[i], getExpected(FragK), 1e-3);
// Keep slots 0, 2. Zero out slots 1, 3.
if(k + 1 < K)
A[m * K + k + 1] = static_cast<T>(0);
if(k + 3 < K)
A[m * K + k + 3] = static_cast<T>(0);
}
}
}
HIP_CHECK_ERROR(hipFree(d_a));
HIP_CHECK_ERROR(hipFree(d_b));
HIP_CHECK_ERROR(hipFree(d_c));
HIP_CHECK_ERROR(hipFree(d_out));
// Fill per-lane A fragments from logical A[M][K] matrix.
// For dense pipelines: AVecType = InternalAVecT[FragsM][FragsK]
// For sparse pipelines: AVecType = ExternalAFragVecT[FragsM][FragsK] (uncompressed)
template <typename Pipeline, typename AScalar>
void fill_a_fragments(typename Pipeline::AVecType* a_per_lane,
const std::vector<AScalar>& A_matrix,
uint32_t K,
uint32_t waveSize)
{
using MmaOp = typename Pipeline::MmaOp;
using ARegMap = TileDistrEncRegMap<typename TileDistrEncCalc<MmaOp>::AWarpDstrEncoding>;
using AFragScalar = typename vector_traits<typename MmaOp::AVecType>::scalar_type;
constexpr uint32_t FragM = Pipeline::FragM;
constexpr uint32_t FragK = Pipeline::FragK;
constexpr uint32_t FragsM = Pipeline::FragsM;
constexpr uint32_t FragsK = Pipeline::FragsK;
constexpr uint32_t kCompressionRatio = MmaOp::kCompressionRatio;
// The A register map maps (lane, vec_idx) -> (m_within_frag, k_within_frag)
// For sparse: k_within_frag is in the compressed K domain (K / kCompressionRatio)
constexpr index_t a_vec_size = ARegMap::num_vector_items;
constexpr index_t external_a_frag_vec_size = a_vec_size * kCompressionRatio;
for(uint32_t lane = 0; lane < waveSize; ++lane)
{
auto* lane_a = reinterpret_cast<AFragScalar*>(&a_per_lane[lane]);
for(uint32_t bm = 0; bm < FragsM; ++bm)
{
for(uint32_t bk = 0; bk < FragsK; ++bk)
{
uint32_t frag_offset = (bm * FragsK + bk) * external_a_frag_vec_size;
if constexpr(kCompressionRatio > 1)
{
// Sparse: fill external (uncompressed) vector
for(index_t ev = 0; ev < external_a_frag_vec_size; ++ev)
{
index_t compressed_v = ev / kCompressionRatio;
index_t sub_pos = ev % kCompressionRatio;
auto coords =
ARegMap::calc_matrix_indices_from_lane_vector(lane, compressed_v);
uint32_t m_local = coords[0];
uint32_t k_compressed = coords[1];
uint32_t k_local = k_compressed * kCompressionRatio + sub_pos;
uint32_t m_global = bm * FragM + m_local;
uint32_t k_global = bk * FragK + k_local;
lane_a[frag_offset + ev] =
static_cast<AFragScalar>(A_matrix[m_global * K + k_global]);
}
}
else
{
// Dense/Scale: direct mapping
for(index_t v = 0; v < a_vec_size; ++v)
{
auto coords = ARegMap::calc_matrix_indices_from_lane_vector(lane, v);
uint32_t m_local = coords[0];
uint32_t k_local = coords[1];
uint32_t m_global = bm * FragM + m_local;
uint32_t k_global = bk * FragK + k_local;
lane_a[frag_offset + v] =
static_cast<AFragScalar>(A_matrix[m_global * K + k_global]);
}
}
}
}
}
}
// Fill per-lane B fragments from logical B[K][N] matrix.
// BVecType = InternalBVecT[FragsN][FragsK]
template <typename Pipeline, typename BScalar>
void fill_b_fragments(typename Pipeline::BVecType* b_per_lane,
const std::vector<BScalar>& B_matrix,
uint32_t N,
uint32_t waveSize)
{
using MmaOp = typename Pipeline::MmaOp;
using BRegMap = TileDistrEncRegMap<typename TileDistrEncCalc<MmaOp>::BWarpDstrEncoding>;
using BFragScalar = typename vector_traits<typename MmaOp::BVecType>::scalar_type;
constexpr uint32_t FragN = Pipeline::FragN;
constexpr uint32_t FragK = Pipeline::FragK;
constexpr uint32_t FragsN = Pipeline::FragsN;
constexpr uint32_t FragsK = Pipeline::FragsK;
constexpr index_t b_vec_size = BRegMap::num_vector_items;
for(uint32_t lane = 0; lane < waveSize; ++lane)
{
auto* lane_b = reinterpret_cast<BFragScalar*>(&b_per_lane[lane]);
for(uint32_t bn = 0; bn < FragsN; ++bn)
{
for(uint32_t bk = 0; bk < FragsK; ++bk)
{
uint32_t frag_offset = (bn * FragsK + bk) * b_vec_size;
for(index_t v = 0; v < b_vec_size; ++v)
{
auto coords = BRegMap::calc_matrix_indices_from_lane_vector(lane, v);
uint32_t n_local = coords[0];
uint32_t k_local = coords[1];
uint32_t n_global = bn * FragN + n_local;
uint32_t k_global = bk * FragK + k_local;
// B matrix is stored as B[K][N]
lane_b[frag_offset + v] =
static_cast<BFragScalar>(B_matrix[k_global * N + n_global]);
}
}
}
}
}
// Extract C matrix from per-lane C fragments.
// CVecType = InternalCVecT[FragsM][FragsN]
template <typename Pipeline, typename CScalar>
void extract_c_matrix(const typename Pipeline::CVecType* c_per_lane,
std::vector<CScalar>& C_matrix,
uint32_t N,
uint32_t waveSize)
{
using MmaOp = typename Pipeline::MmaOp;
using CRegMap = TileDistrEncRegMap<typename TileDistrEncCalc<MmaOp>::CWarpDstrEncoding>;
using CFragScalar = typename vector_traits<typename MmaOp::CVecType>::scalar_type;
constexpr uint32_t FragM = Pipeline::FragM;
constexpr uint32_t FragN = Pipeline::FragN;
constexpr uint32_t FragsM = Pipeline::FragsM;
constexpr uint32_t FragsN = Pipeline::FragsN;
constexpr index_t c_vec_size = CRegMap::num_vector_items;
for(uint32_t lane = 0; lane < waveSize; ++lane)
{
auto* lane_c = reinterpret_cast<const CFragScalar*>(&c_per_lane[lane]);
for(uint32_t bm = 0; bm < FragsM; ++bm)
{
for(uint32_t bn = 0; bn < FragsN; ++bn)
{
uint32_t frag_offset = (bm * FragsN + bn) * c_vec_size;
for(index_t v = 0; v < c_vec_size; ++v)
{
auto coords = CRegMap::calc_matrix_indices_from_lane_vector(lane, v);
uint32_t m_local = coords[0];
uint32_t n_local = coords[1];
uint32_t m_global = bm * FragM + m_local;
uint32_t n_global = bn * FragN + n_local;
C_matrix[m_global * N + n_global] =
static_cast<CScalar>(lane_c[frag_offset + v]);
}
}
}
}
}
/// Internal: runs the test with a fully resolved Pipeline type.
/// Called from run_pipeline_matrix_test after dispatching on compiler target.
template <typename Pipeline,
typename KernelType,
typename AScalar = fp16_t,
typename BScalar = fp16_t,
typename CScalar = fp32_t>
void run_pipeline_matrix_test_impl(uint32_t M,
uint32_t N,
uint32_t K,
uint32_t waveSize,
KernelType kernel,
bool isSparse,
bool transposeExpected = false,
float referenceScale = 1.0f)
{
std::vector<AScalar> A_matrix(M * K);
std::vector<BScalar> B_matrix(K * N);
std::vector<CScalar> C_expected(M * N, static_cast<CScalar>(0));
std::vector<CScalar> C_actual(M * N, static_cast<CScalar>(0));
for(uint32_t m = 0; m < M; ++m)
for(uint32_t k = 0; k < K; ++k)
A_matrix[m * K + k] = deterministic_value<AScalar>(m, k, K);
for(uint32_t k = 0; k < K; ++k)
for(uint32_t n = 0; n < N; ++n)
B_matrix[k * N + n] = deterministic_value<BScalar>(k, n, N);
if(isSparse)
{
apply_sparse_pattern(A_matrix, M, K);
}
void
test_pipeline(std::function<bool(ck_tile::core::arch::amdgcn_target_id)> shouldSkip,
std::function<void(uint32_t, void*, void*, void*, void*, void*, void*)> kernel,
std::function<CType(uint32_t, ScaleAType, ScaleBType)> getExpected,
std::function<AType(size_t)> aInitializer = nullptr)
reference_matmul(C_expected, A_matrix, B_matrix, M, N, K);
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
const size_t a_buf_size = waveSize * sizeof(AVecType);
const size_t b_buf_size = waveSize * sizeof(BVecType);
const size_t c_buf_size = waveSize * sizeof(CVecType);
std::vector<uint8_t> h_a(a_buf_size, 0);
std::vector<uint8_t> h_b(b_buf_size, 0);
std::vector<uint8_t> h_c(c_buf_size, 0);
fill_a_fragments<Pipeline>(reinterpret_cast<AVecType*>(h_a.data()), A_matrix, K, waveSize);
fill_b_fragments<Pipeline>(reinterpret_cast<BVecType*>(h_b.data()), B_matrix, N, waveSize);
void *d_a, *d_b, *d_c;
HIP_CHECK_ERROR(hipMalloc(&d_a, a_buf_size));
HIP_CHECK_ERROR(hipMalloc(&d_b, b_buf_size));
HIP_CHECK_ERROR(hipMalloc(&d_c, c_buf_size));
HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), a_buf_size, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), b_buf_size, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemset(d_c, 0, c_buf_size));
ck_tile::launch_kernel(ck_tile::stream_config{},
ck_tile::make_kernel(kernel, dim3(1), dim3(waveSize), 0, d_a, d_b, d_c));
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipMemcpy(h_c.data(), d_c, c_buf_size, hipMemcpyDeviceToHost));
extract_c_matrix<Pipeline>(
reinterpret_cast<const CVecType*>(h_c.data()), C_actual, N, waveSize);
for(uint32_t m = 0; m < M; ++m)
{
using namespace ck_tile;
using namespace ck_tile::core::arch;
int devCount;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceCount(&devCount));
hipDeviceProp_t devProp;
HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev));
auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName);
bool hasDevice = static_cast<bool>(devCount > 0);
int deviceWarpSize = devProp.warpSize;
if(!hasDevice || shouldSkip(currentArchId))
for(uint32_t n = 0; n < N; ++n)
{
GTEST_SKIP() << "No HIP device found. Skipping test.";
// When transposeExpected is true, the kernel computes C^T via SwapAB,
// so compare actual C[m][n] against reference C[n][m].
constexpr float relative_tolerance = 1e-2f;
constexpr float absolute_tolerance = 1e-3f;
float expected = transposeExpected ? static_cast<float>(C_expected[n * M + m])
: static_cast<float>(C_expected[m * N + n]);
expected *= referenceScale;
float actual = static_cast<float>(C_actual[m * N + n]);
EXPECT_NEAR(
actual, expected, std::abs(expected) * relative_tolerance + absolute_tolerance)
<< "Mismatch at C[" << m << "][" << n << "]";
}
// WaveTile size, also the expected fragment size (MmaTile) from the selector.
// Note: Actual FragK might be slightly different due to hardware implementation, but the
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
// correct.
static constexpr uint32_t FragM = WaveTileM;
static constexpr uint32_t FragN = WaveTileN;
static constexpr uint32_t FragK = WaveTileK;
// The number of elements per thread
uint32_t AElements = FragM * FragK / deviceWarpSize / numeric_traits<AType>::PackedSize;
uint32_t BElements = FragN * FragK / deviceWarpSize / numeric_traits<BType>::PackedSize;
uint32_t CElements = FragM * FragN / deviceWarpSize;
uint32_t ASize = AElements * sizeof(AType);
uint32_t BSize = BElements * sizeof(BType);
uint32_t CSize = CElements * sizeof(CType);
uint32_t ScaleASize = 1 * sizeof(ScaleAType);
uint32_t ScaleBSize = 1 * sizeof(ScaleBType);
// Initialize A (use custom initializer or default all 1's), B to all 1's, C to all 0's
std::vector<AType> h_a(AElements);
if(aInitializer)
{
for(size_t i = 0; i < AElements; ++i)
h_a[i] = aInitializer(i);
}
else
{
std::fill(h_a.begin(), h_a.end(), type_convert<AType>(1.0f));
}
std::vector<BType> h_b(BElements, type_convert<BType>(1.0f));
std::vector<CType> h_c(CElements, type_convert<CType>(0.0f));
std::vector<CType> h_out(CElements, type_convert<CType>(0.0f));
// The actual scale is computed as pow(2, scale - 127), so:
// 126 -> 2^-1 and 129 -> 2^2.
ScaleAType h_scale_a = 126;
ScaleBType h_scale_b = 129;
AType* d_a;
BType* d_b;
CType* d_c;
CType* d_out;
ScaleAType* d_scale_a;
ScaleBType* d_scale_b;
HIP_CHECK_ERROR(hipMalloc(&d_a, ASize));
HIP_CHECK_ERROR(hipMalloc(&d_b, BSize));
HIP_CHECK_ERROR(hipMalloc(&d_c, CSize));
HIP_CHECK_ERROR(hipMalloc(&d_out, CSize));
HIP_CHECK_ERROR(hipMalloc(&d_scale_a, ScaleASize));
HIP_CHECK_ERROR(hipMalloc(&d_scale_b, ScaleBSize));
// Copy inputs to device
HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_scale_a, &h_scale_a, ScaleASize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_scale_b, &h_scale_b, ScaleBSize, hipMemcpyHostToDevice));
const auto wave_size = getDeviceWaveSize();
kernel(wave_size, d_a, d_b, d_c, d_out, d_scale_a, d_scale_b);
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
// Verify output against expected value for all elements
for(size_t i = 0; i < CElements; ++i)
{
EXPECT_NEAR(h_out[i], getExpected(FragK, h_scale_a, h_scale_b), 1e-3);
}
HIP_CHECK_ERROR(hipFree(d_a));
HIP_CHECK_ERROR(hipFree(d_b));
HIP_CHECK_ERROR(hipFree(d_c));
HIP_CHECK_ERROR(hipFree(d_out));
HIP_CHECK_ERROR(hipFree(d_scale_a));
HIP_CHECK_ERROR(hipFree(d_scale_b));
}
};
HIP_CHECK_ERROR(hipFree(d_a));
HIP_CHECK_ERROR(hipFree(d_b));
HIP_CHECK_ERROR(hipFree(d_c));
}
/// @tparam PipelineFactory A template template that, given a CompilerTarget type, produces
/// the Pipeline type: PipelineFactory<Target>::type
/// @tparam KernelType Kernel functor struct with kBlockSize and __device__ operator()
/// @tparam AScalar Scalar type for A matrix (e.g., fp16_t)
/// @tparam BScalar Scalar type for B matrix (e.g., fp16_t)
/// @tparam CScalar Scalar type for C matrix (e.g., fp32_t)
/// @param M WaveTile M dimension
/// @param N WaveTile N dimension
/// @param K WaveTile K dimension
/// @param shouldSkip Predicate returning true if current device should skip
/// @param kernel Kernel functor instance to launch via make_kernel
/// @param isSparse Whether to apply 2:4 sparsity pattern to A
/// @param transposeExpected When true, compare against transposed reference (for
/// SwapAB/TransposeC)
/// @param referenceScale Scalar multiplier applied to the reference matmul result before
/// comparison (e.g., to account for scale-MMA scaling factors)
template <template <typename> class PipelineFactory,
typename KernelType,
typename AScalar = fp16_t,
typename BScalar = fp16_t,
typename CScalar = fp32_t>
void run_pipeline_matrix_test(uint32_t M,
uint32_t N,
uint32_t K,
std::function<bool(ck_tile::core::arch::amdgcn_target_id)> shouldSkip,
KernelType kernel,
bool isSparse = false,
bool transposeExpected = false,
float referenceScale = 1.0f)
{
int devCount;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceCount(&devCount));
hipDeviceProp_t devProp;
HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev));
auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName);
if(devCount <= 0 || shouldSkip(currentArchId))
{
GTEST_SKIP() << "No HIP device found or arch (0x" << std::hex
<< static_cast<int>(currentArchId) << ") not supported. Skipping test.";
}
if(!hipTargetMatchesCmakeTargets(currentArchId))
{
std::cout << "The GPU targets exposed by CMake are: ";
for(const auto& target : getCMakeGpuTargetIds())
{
std::cout << "(0x" << std::hex << static_cast<int>(target) << ")\n";
}
FAIL() << "The HIP device (0x" << std::hex << static_cast<int>(currentArchId)
<< ") does not match the compiler target(s).";
}
const uint32_t waveSize = static_cast<uint32_t>(devProp.warpSize);
bool dispatched = dispatchCompilerTarget(currentArchId, [&](auto target) {
using CompilerTarget = decltype(target);
using Pipeline = typename PipelineFactory<CompilerTarget>::type;
run_pipeline_matrix_test_impl<Pipeline, KernelType, AScalar, BScalar, CScalar>(
M, N, K, waveSize, kernel, isSparse, transposeExpected, referenceScale);
});
if(!dispatched)
{
GTEST_SKIP() << "Cannot dispatch on HOST target.";
}
}
} // namespace mma_pipeline_test

View File

@@ -18,7 +18,7 @@ TEST(MmaPipelineOptionFlagsTests, ConversionTests)
MmaPipelineOptionFlags flags_0{};
MmaPipelineOptionFlags flags_1{MmaPipelineOptionFlag::ABSwap};
MmaPipelineOptionFlags flags_2{MmaPipelineOptionFlag::COMPRESS_A};
MmaPipelineOptionFlags flags_3{0b11};
MmaPipelineOptionFlags flags_3{0b11}; // TODO c++20 - remove this
EXPECT_TRUE(flags_0.testFlag(MmaPipelineOptionFlag::NONE));
EXPECT_FALSE(flags_0.testFlag(MmaPipelineOptionFlag::ABSwap));

View File

@@ -5,17 +5,16 @@
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include <gtest/gtest.h>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <type_traits>
@@ -44,21 +43,21 @@ void ScaleMfmaGfx950Specialization_impl()
CompilerTargetGfx950,
MmaOpFamily::SCALE>;
static_assert(std::is_same_v<typename TestScaleMma::OpType, MfmaOp> &&
TestScaleMma::OpFamily == MmaOpFamily::SCALE,
"GFX950 scale intrinsic should have ScaleMFMAOp type");
EXPECT_TRUE((std::is_same_v<typename TestScaleMma::OpType, MfmaOp> &&
TestScaleMma::OpFamily == MmaOpFamily::SCALE))
<< "GFX950 scale intrinsic should have ScaleMFMAOp type";
static_assert(is_mma_op_of_family_v<MmaOpFamily::SCALE, TestScaleMma>,
"GFX950 scale intrinsic should be detected as Scale");
EXPECT_TRUE((is_mma_op_of_family_v<MmaOpFamily::SCALE, TestScaleMma>))
<< "GFX950 scale intrinsic should be detected as Scale";
// Get its traits
using TestTraits = MmaOpTraits<TestScaleMma>;
// Verify trait detection
static_assert(TestTraits::IsScale, "Scale MMA should be detected as scale");
static_assert(TestTraits::IsSupported, "Scale MMA specialization should be supported");
static_assert(TestTraits::IsMfma, "Scale MFMA should be detected as MFMA");
static_assert(!TestTraits::IsWmma, "Scale MFMA should not be detected as WMMA");
EXPECT_TRUE(TestTraits::IsScale) << "Scale MMA should be detected as scale";
EXPECT_TRUE(TestTraits::IsSupported) << "Scale MMA specialization should be supported";
EXPECT_TRUE(TestTraits::IsMfma) << "Scale MFMA should be detected as MFMA";
EXPECT_FALSE(TestTraits::IsWmma) << "Scale MFMA should not be detected as WMMA";
}
TEST(ScaleMMATrait, ScaleMfmaGfx950Specialization)
@@ -67,14 +66,10 @@ TEST(ScaleMMATrait, ScaleMfmaGfx950Specialization)
ScaleMfmaGfx950Specialization_impl<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u>();
// Test bf8 -> fp32 scale MFMA for GFX950 (16x16x128)
ScaleMfmaGfx950Specialization_impl<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u>();
// Test fp4 -> fp32 scale MFMA for GFX950 (16x16x128)
ScaleMfmaGfx950Specialization_impl<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u>();
// Test fp8 -> fp32 scale MFMA for GFX950 (32x32x64)
ScaleMfmaGfx950Specialization_impl<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u>();
// Test bf8 -> fp32 scale MFMA for GFX950 (32x32x64)
ScaleMfmaGfx950Specialization_impl<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u>();
// Test fp4 -> fp32 scale MFMA for GFX950 (32x32x64)
ScaleMfmaGfx950Specialization_impl<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u>();
std::cout << "GFX950 scale MFMA specialization is correct" << std::endl;
}
@@ -97,7 +92,7 @@ void TestConceptRequirements_impl()
DefaultScaleMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SCALE>;
static_assert(MmaOpI<TestScaleMma>);
EXPECT_TRUE(MmaOpI<TestScaleMma>);
}
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
@@ -106,10 +101,8 @@ TEST(ScaleMMATrait, TestConceptRequirements)
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
TestConceptRequirements_impl<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u>();
TestConceptRequirements_impl<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u>();
TestConceptRequirements_impl<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u>();
TestConceptRequirements_impl<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u>();
TestConceptRequirements_impl<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u>();
TestConceptRequirements_impl<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u>();
#else
GTEST_SKIP() << "Not compiled with concepts. Skipping test.";
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
@@ -132,15 +125,15 @@ void ScaleSelector_impl()
if constexpr(isValid)
{
// Selector should pick a scale MFMA implementation
static_assert(MmaOpTraits<Selected>::IsScale);
static_assert(MmaOpTraits<Selected>::IsMfma);
static_assert(MmaOpTraits<Selected>::IsSupported);
static_assert((std::is_same<typename Selected::OpType, MfmaOp>::value));
EXPECT_TRUE(MmaOpTraits<Selected>::IsScale);
EXPECT_TRUE(MmaOpTraits<Selected>::IsMfma);
EXPECT_TRUE(MmaOpTraits<Selected>::IsSupported);
EXPECT_TRUE((std::is_same<typename Selected::OpType, MfmaOp>::value));
}
else
{
// Selector should pick the unsupported pass through
static_assert(!MmaOpTraits<Selected>::IsSupported);
EXPECT_FALSE(MmaOpTraits<Selected>::IsSupported);
}
});
});
@@ -150,7 +143,6 @@ TEST(ScaleMMATrait, ScaleSelector)
{
ScaleSelector_impl<fp8_t, fp8_t, fp32_t>();
ScaleSelector_impl<bf8_t, bf8_t, fp32_t>();
ScaleSelector_impl<pk_fp4_t, pk_fp4_t, fp32_t>();
}
template <typename AType,
@@ -161,34 +153,72 @@ template <typename AType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK>
__global__ void
test_scale_accum_over_k(void* a, void* b, void* c, void* out, void* scale_A, void* scale_B)
struct ScalePipelineKernel
{
using Pipeline = ScaleMmaPipeline<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>;
static constexpr int kBlockSize = mma_pipeline_test::getCMakeWaveSize();
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
// NOTE: WaveTileK is used as a Pipeline template parameter, but the K iteration is
// happening outside the Pipeline. This is a bit incorrect currently.
static constexpr std::uint32_t kIters = WaveTileK / Pipeline::MmaOp::kK;
// Initialize the accumulator
CVecType result = *reinterpret_cast<CVecType*>(c);
// Accumulate input AxB over WaveTileK/FragK iterations
for(std::uint32_t i = 0; i < kIters; ++i)
__device__ void
operator()(const void* a_per_lane, const void* b_per_lane, void* c_per_lane) const
{
result = Pipeline::exec(*reinterpret_cast<AVecType*>(a),
*reinterpret_cast<BVecType*>(b),
result,
*reinterpret_cast<ScaleAType*>(scale_A),
*reinterpret_cast<ScaleBType*>(scale_B));
}
using Pipeline = ScaleMmaPipeline<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>;
*reinterpret_cast<CVecType*>(out) = result;
}
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
const uint32_t lane = threadIdx.x;
AVecType a;
BVecType b;
CVecType c;
__builtin_memcpy(&a,
static_cast<const uint8_t*>(a_per_lane) + lane * sizeof(AVecType),
sizeof(AVecType));
__builtin_memcpy(&b,
static_cast<const uint8_t*>(b_per_lane) + lane * sizeof(BVecType),
sizeof(BVecType));
__builtin_memset(&c, 0, sizeof(CVecType));
if constexpr(MmaOpTraits<typename Pipeline::MmaOp>::IsSupported)
{
// Each lane has a single 8-bit E8M0 scale that applies to all
// 32 A/B elements in that lane. The byte's position within the
// VGPR is selected by opsel. Replicating the byte to all 4
// positions makes the value opsel-independent.
// scale_a byte = 126 -> 2^(126-127) = 2^-1 = 0.5
// scale_b byte = 129 -> 2^(129-127) = 2^2 = 4.0
// Combined scale factor = 0.5 * 4.0 = 2.0
constexpr int32_t replicate_byte = 0x01010101;
ScaleAType scale_a = 126u * replicate_byte;
ScaleBType scale_b = 129u * replicate_byte;
Pipeline::exec(a, b, c, scale_a, scale_b);
__builtin_memcpy(
static_cast<uint8_t*>(c_per_lane) + lane * sizeof(CVecType), &c, sizeof(CVecType));
}
}
};
template <typename AType,
typename BType,
typename CType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK>
struct ScalePipelineFactory
{
template <typename Target>
struct Create
{
using type = ScaleMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
MmaAccumPolicy::ROW_MAJOR,
Target>;
};
};
template <typename AType,
typename BType,
@@ -198,39 +228,37 @@ template <typename AType,
std::uint32_t WaveTileK>
void MmaSelector_Scale_Real_impl()
{
using TestType = MmaPipelineTest<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>;
TestType test;
using ScaleAType = std::int32_t;
using ScaleBType = std::int32_t;
const auto should_skip = [](amdgcn_target_id currentArchId) {
bool isSupportedWmma = false;
bool isSupportedMfma = (currentArchId == amdgcn_target_id::GFX950);
return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma));
return ((currentArchId == amdgcn_target_id::HOST) || !isSupportedMfma);
};
const std::function<fp32_t(
std::uint32_t, typename TestType::ScaleAType, typename TestType::ScaleBType)>
validator =
[](std::uint32_t fragK, TestType::ScaleAType scale_A, TestType::ScaleBType scale_B) {
fp32_t actual_scale_A = std::powf(2.0f, scale_A - 127.0f);
fp32_t actual_scale_B = std::powf(2.0f, scale_B - 127.0f);
return static_cast<fp32_t>(fragK) * actual_scale_A * actual_scale_B;
};
const auto kernel = [](std::uint32_t waveSize,
void* a,
void* b,
void* c,
void* out,
void* scale_A,
void* scale_B) {
test_scale_accum_over_k<typename TestType::AType,
typename TestType::BType,
typename TestType::CType,
typename TestType::ScaleAType,
typename TestType::ScaleBType,
TestType::WaveTileM,
TestType::WaveTileN,
TestType::WaveTileK>
<<<1, waveSize>>>(a, b, c, out, scale_A, scale_B);
};
test.test_pipeline(should_skip, kernel, validator);
using Factory = ScalePipelineFactory<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>;
using Kernel = ScalePipelineKernel<AType,
BType,
CType,
ScaleAType,
ScaleBType,
WaveTileM,
WaveTileN,
WaveTileK>;
// scale_a=126 -> 2^-1=0.5, scale_b=129 -> 2^2=4.0 -> combined = 2.0
constexpr float reference_scale = 2.0f;
mma_pipeline_test::
run_pipeline_matrix_test<Factory::template Create, Kernel, AType, BType, CType>(
WaveTileM,
WaveTileN,
WaveTileK,
should_skip,
Kernel{},
/*isSparse=*/false,
/*transposeExpected=*/false,
reference_scale);
}
// Live test on real hardware for scale selection and execution.
@@ -245,12 +273,6 @@ TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_16x16x128_Real)
MmaSelector_Scale_Real_impl<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u>();
}
// Live test on real hardware for scale selection and execution.
TEST(ScaleMMATrait, MmaSelector_Scale_F4_F4_F32_16x16x128_Real)
{
MmaSelector_Scale_Real_impl<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u>();
}
// Live test on real hardware for scale selection and execution.
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_32x32x64_Real)
{
@@ -263,8 +285,215 @@ TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_32x32x64_Real)
MmaSelector_Scale_Real_impl<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u>();
}
// Live test on real hardware for scale selection and execution.
TEST(ScaleMMATrait, MmaSelector_Scale_F4_F4_F32_32x32x64_Real)
// ---------------------------------------------------------------------------
// Multi-fragment (WaveWise) scale pipeline tests
// ---------------------------------------------------------------------------
// Kernel functor with AccumPolicy support for multi-fragment scale pipeline tests.
template <typename AType,
typename BType,
typename CType,
typename ScaleAType,
typename ScaleBType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy>
struct ScaleWaveWisePipelineKernel
{
MmaSelector_Scale_Real_impl<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u>();
static constexpr int kBlockSize = mma_pipeline_test::getCMakeWaveSize();
__device__ void
operator()(const void* a_per_lane, const void* b_per_lane, void* c_per_lane) const
{
using CompilerTarget = decltype(get_compiler_target());
using Pipeline = ScaleMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
AccumPolicy,
CompilerTarget>;
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
const uint32_t lane = threadIdx.x;
AVecType a;
BVecType b;
CVecType c;
__builtin_memcpy(&a,
static_cast<const uint8_t*>(a_per_lane) + lane * sizeof(AVecType),
sizeof(AVecType));
__builtin_memcpy(&b,
static_cast<const uint8_t*>(b_per_lane) + lane * sizeof(BVecType),
sizeof(BVecType));
__builtin_memset(&c, 0, sizeof(CVecType));
if constexpr(MmaOpTraits<typename Pipeline::MmaOp>::IsSupported)
{
// Each lane has a single 8-bit E8M0 scale that applies to all
// 32 A/B elements in that lane. The byte's position within the
// VGPR is selected by opsel. Replicating the byte to all 4
// positions makes the value opsel-independent.
// scale_a byte = 126 -> 2^(126-127) = 2^-1 = 0.5
// scale_b byte = 129 -> 2^(129-127) = 2^2 = 4.0
// Combined scale factor = 0.5 * 4.0 = 2.0
constexpr int32_t replicate_byte = 0x01010101;
ScaleAType scale_a = 126u * replicate_byte;
ScaleBType scale_b = 129u * replicate_byte;
Pipeline::exec(a, b, c, scale_a, scale_b);
__builtin_memcpy(
static_cast<uint8_t*>(c_per_lane) + lane * sizeof(CVecType), &c, sizeof(CVecType));
}
}
};
template <typename AType,
typename BType,
typename CType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy>
struct ScaleWaveWisePipelineFactory
{
template <typename Target>
struct Create
{
using type = ScaleMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
AccumPolicy,
Target>;
};
};
template <typename AType,
typename BType,
typename CType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR>
void MmaSelector_Scale_WaveWise_Real_impl()
{
using ScaleAType = std::int32_t;
using ScaleBType = std::int32_t;
const auto should_skip = [](amdgcn_target_id currentArchId) {
bool isSupportedMfma = (currentArchId == amdgcn_target_id::GFX950);
return ((currentArchId == amdgcn_target_id::HOST) || !isSupportedMfma);
};
using Factory = ScaleWaveWisePipelineFactory<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
AccumPolicy>;
using Kernel = ScaleWaveWisePipelineKernel<AType,
BType,
CType,
ScaleAType,
ScaleBType,
WaveTileM,
WaveTileN,
WaveTileK,
AccumPolicy>;
// scale_a=126 -> 2^-1=0.5, scale_b=129 -> 2^2=4.0 -> combined = 2.0
constexpr float reference_scale = 2.0f;
mma_pipeline_test::
run_pipeline_matrix_test<Factory::template Create, Kernel, AType, BType, CType>(
WaveTileM,
WaveTileN,
WaveTileK,
should_skip,
Kernel{},
/*isSparse=*/false,
/*transposeExpected=*/false,
reference_scale);
}
// Multi-fragment tests: 64x64x64 uses 32x32x64 op -> FragsM=2, FragsN=2, FragsK=1
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_64x64x64_WaveWise_RowMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<fp8_t,
fp8_t,
fp32_t,
64u,
64u,
64u,
MmaAccumPolicy::ROW_MAJOR>();
}
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_64x64x64_WaveWise_ColMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<fp8_t,
fp8_t,
fp32_t,
64u,
64u,
64u,
MmaAccumPolicy::COL_MAJOR>();
}
TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_64x64x64_WaveWise_RowMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<bf8_t,
bf8_t,
fp32_t,
64u,
64u,
64u,
MmaAccumPolicy::ROW_MAJOR>();
}
// Multi-fragment tests: 32x32x128 uses 32x32x64 op -> FragsM=1, FragsN=1, FragsK=2
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_32x32x128_WaveWise_RowMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<fp8_t, fp8_t, fp32_t, 32u, 32u, 128u>();
}
TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_32x32x128_WaveWise_RowMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<bf8_t, bf8_t, fp32_t, 32u, 32u, 128u>();
}
// Multi-fragment tests: 64x64x128 uses 32x32x64 op -> FragsM=2, FragsN=2, FragsK=2
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_64x64x128_WaveWise_RowMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<fp8_t, fp8_t, fp32_t, 64u, 64u, 128u>();
}
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_64x64x128_WaveWise_ColMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<fp8_t,
fp8_t,
fp32_t,
64u,
64u,
128u,
MmaAccumPolicy::COL_MAJOR>();
}
// Multi-fragment tests with 16x16x128 op: 32x16x128 -> FragsM=2, FragsN=1, FragsK=1
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_32x16x128_WaveWise_RowMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<fp8_t, fp8_t, fp32_t, 32u, 16u, 128u>();
}
// Multi-fragment tests with 16x16x128 op: 16x32x128 -> FragsM=1, FragsN=2, FragsK=1
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_16x32x128_WaveWise_RowMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<fp8_t, fp8_t, fp32_t, 16u, 32u, 128u>();
}

View File

@@ -8,6 +8,7 @@
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp"
@@ -41,14 +42,12 @@ TEST(SparseMMATrait, SparseMfmaGfx950Specialization)
CompilerTargetGfx950,
MmaOpFamily::SPARSE>;
static_assert(std::is_same_v<typename TestSparseMfma16x16::OpType, MfmaOp> &&
TestSparseMfma16x16::OpFamily == MmaOpFamily::SPARSE,
"GFX950 sparse 16x16x32 should have SparseMFMAOp type");
EXPECT_TRUE((std::is_same_v<typename TestSparseMfma16x16::OpType, MfmaOp> &&
TestSparseMfma16x16::OpFamily == MmaOpFamily::SPARSE))
<< "GFX950 sparse 16x16x32 should have SparseMFMAOp type";
static_assert(is_mma_op_of_family_v<MmaOpFamily::SPARSE, TestSparseMfma16x16>,
"GFX950 sparse 16x16x32 should be detected as Sparse");
std::cout << "GFX950 sparse MFMA specialization is correct" << std::endl;
EXPECT_TRUE((is_mma_op_of_family_v<MmaOpFamily::SPARSE, TestSparseMfma16x16>))
<< "GFX950 sparse 16x16x32 should be detected as Sparse";
}
TEST(SparseMMATrait, MmaOpTraitsIntegration)
@@ -68,12 +67,10 @@ TEST(SparseMMATrait, MmaOpTraitsIntegration)
using TestTraits = MmaOpTraits<TestSparseMmma>;
// Verify trait detection
static_assert(TestTraits::IsSparse, "Sparse MMA should be detected as sparse");
static_assert(TestTraits::IsSupported, "Sparse MMA specialization should be supported");
static_assert(TestTraits::IsMfma, "Sparse MFMA should be detected as MFMA");
static_assert(!TestTraits::IsWmma, "Sparse MFMA should not be detected as WMMA");
std::cout << "MmaOpTraits correctly integrates sparse operations" << std::endl;
EXPECT_TRUE(TestTraits::IsSparse) << "Sparse MMA should be detected as sparse";
EXPECT_TRUE(TestTraits::IsSupported) << "Sparse MMA specialization should be supported";
EXPECT_TRUE(TestTraits::IsMfma) << "Sparse MFMA should be detected as MFMA";
EXPECT_FALSE(TestTraits::IsWmma) << "Sparse MFMA should not be detected as WMMA";
}
TEST(SparseMMATrait, TestConceptRequirements)
@@ -88,7 +85,7 @@ TEST(SparseMMATrait, TestConceptRequirements)
DefaultSparseMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SPARSE>;
static_assert(MmaOpI<TestSparseMmma>);
EXPECT_TRUE(MmaOpI<TestSparseMmma>);
#else
GTEST_SKIP() << "Not compiled with concepts. Skipping test.";
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
@@ -119,22 +116,20 @@ TEST(SparseMMATrait, DenseVsSparseDistinction)
MmaOpFamily::SPARSE>;
// Verify they have different operation types
static_assert(std::is_same_v<typename DenseMfma::OpType, typename SparseMfma::OpType> &&
DenseMfma::OpFamily != SparseMfma::OpFamily,
"Dense and Sparse MFMA should have the same OpType tags and different OpFamily");
EXPECT_TRUE((std::is_same_v<typename DenseMfma::OpType, typename SparseMfma::OpType> &&
DenseMfma::OpFamily != SparseMfma::OpFamily))
<< "Dense and Sparse MFMA should have the same OpType tags and different OpFamily";
// Verify traits correctly identify them
static_assert(MmaOpTraits<DenseMfma>::IsMfma && MmaOpTraits<DenseMfma>::IsDense &&
!MmaOpTraits<DenseMfma>::IsSparse && !MmaOpTraits<DenseMfma>::IsScale &&
MmaOpTraits<DenseMfma>::IsSupported,
"Dense MFMA should be identified correctly");
EXPECT_TRUE((MmaOpTraits<DenseMfma>::IsMfma && MmaOpTraits<DenseMfma>::IsDense &&
!MmaOpTraits<DenseMfma>::IsSparse && !MmaOpTraits<DenseMfma>::IsScale &&
MmaOpTraits<DenseMfma>::IsSupported))
<< "Dense MFMA should be identified correctly";
static_assert(MmaOpTraits<SparseMfma>::IsSparse && MmaOpTraits<SparseMfma>::IsMfma &&
!MmaOpTraits<SparseMfma>::IsDense && !MmaOpTraits<SparseMfma>::IsScale &&
MmaOpTraits<SparseMfma>::IsSupported,
"Sparse MFMA should be identified correctly");
std::cout << "Dense and sparse MMA operations are correctly distinguished" << std::endl;
EXPECT_TRUE((MmaOpTraits<SparseMfma>::IsSparse && MmaOpTraits<SparseMfma>::IsMfma &&
!MmaOpTraits<SparseMfma>::IsDense && !MmaOpTraits<SparseMfma>::IsScale &&
MmaOpTraits<SparseMfma>::IsSupported))
<< "Sparse MFMA should be identified correctly";
}
TEST(SparseMMATrait, SparseSelector)
@@ -153,100 +148,53 @@ TEST(SparseMMATrait, SparseSelector)
if constexpr(isValid)
{
// Selector should pick a sparse MFMA implementation
static_assert(MmaOpTraits<Selected>::IsSparse);
static_assert(MmaOpTraits<Selected>::IsMfma);
static_assert(MmaOpTraits<Selected>::IsSupported);
static_assert((std::is_same<typename Selected::OpType, MfmaOp>::value));
EXPECT_TRUE(MmaOpTraits<Selected>::IsSparse);
EXPECT_TRUE(MmaOpTraits<Selected>::IsMfma);
EXPECT_TRUE(MmaOpTraits<Selected>::IsSupported);
EXPECT_TRUE((std::is_same<typename Selected::OpType, MfmaOp>::value));
}
else
{
// Selector should pick the unsupported pass through
static_assert(!MmaOpTraits<Selected>::IsSupported);
EXPECT_FALSE(MmaOpTraits<Selected>::IsSupported);
}
});
}
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK>
__global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out)
{
using Pipeline = SparseMmaPipeline<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>;
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
static constexpr uint32_t kIters = WaveTileK / Pipeline::MmaOp::kK;
// Initialize the accumulator
CVecType result = *reinterpret_cast<CVecType*>(c);
// Accumulate input AxB over WaveTileK/FragK iterations
for(uint32_t i = 0; i < kIters; ++i)
{
result = Pipeline::exec(
*reinterpret_cast<AVecType*>(a), *reinterpret_cast<BVecType*>(b), result);
}
*reinterpret_cast<CVecType*>(out) = result;
}
// Live test on real hardware for sparse selection and execution.
TEST(SparseMMATrait, MmaSelector_Sparse_F16_F16_F32_16x16x32_Real)
{
MmaPipelineTest<> test;
const auto should_skip = [](amdgcn_target_id currentArchId) {
bool isSupportedWmma = (currentArchId >= amdgcn_target_id::GFX1200) &&
(currentArchId <= amdgcn_target_id::GFX12_GENERIC);
bool isSupportedMfma = (currentArchId >= amdgcn_target_id::GFX942) &&
(currentArchId <= amdgcn_target_id::GFX950);
return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma));
};
const std::function<fp32_t(uint32_t)> validator = [](uint32_t waveTileK) {
return static_cast<fp32_t>(waveTileK) / 2;
};
const auto kernel = [](uint32_t waveSize, void* a, void* b, void* c, void* out) {
test_sparse_accum_over_k<MmaPipelineTest<>::AType,
MmaPipelineTest<>::BType,
MmaPipelineTest<>::CType,
MmaPipelineTest<>::WaveTileM,
MmaPipelineTest<>::WaveTileN,
MmaPipelineTest<>::WaveTileK><<<1, waveSize>>>(a, b, c, out);
};
// Initialize A with 2:4 structured sparsity pattern: {1, 0, 1, 0, ...}
// This ensures the sparse compression transform is actually exercised -
// a no-op or broken compression would pass zeros through, causing incorrect results.
const std::function<fp16_t(size_t)> sparseAInit = [](size_t i) -> fp16_t {
return (i % 2 == 0) ? type_convert<fp16_t>(1) : type_convert<fp16_t>(0);
};
test.test_pipeline(should_skip, kernel, validator, sparseAInit);
}
template <uint32_t CompressionRatio, typename Vec>
__global__ void test_sparse_transform(void* a, void* idx)
struct SparseTransformKernel
{
using ResultT =
decltype(SparseCompressTransform<CompressionRatio>::exec(*static_cast<Vec*>(a)));
using FirstT = std::tuple_element_t<0, ResultT>;
const auto& [vec, i] = SparseCompressTransform<CompressionRatio>::exec(*static_cast<Vec*>(a));
*reinterpret_cast<remove_cvref_t<FirstT>*>(a) = vec;
*reinterpret_cast<int32_t*>(idx) = i;
}
static constexpr int kBlockSize = mma_pipeline_test::getCMakeWaveSize();
__device__ void operator()(void* a, void* idx) const
{
using ResultT =
decltype(SparseCompressTransform<CompressionRatio>::exec(*static_cast<Vec*>(a)));
using FirstT = std::tuple_element_t<0, ResultT>;
using IdxT = std::tuple_element_t<1, ResultT>;
const auto& [vec, i] =
SparseCompressTransform<CompressionRatio>::exec(*static_cast<Vec*>(a));
*reinterpret_cast<remove_cvref_t<FirstT>*>(a) = vec;
__builtin_memcpy(idx, &i, sizeof(IdxT));
}
};
// Generalized helper: runs the sparse transform kernel and verifies compressed output and index.
template <int NUM, int RATIO, typename Type>
void sparse_transform_verify(const std::vector<Type>& input,
const std::vector<Type>& expected_output,
int32_t expected_idx)
void sparse_transform_verify(
const std::vector<Type>& input,
const std::vector<Type>& expected_output,
const sparse::detail::SparseIdxPack<sparse::detail::idx_words_needed<NUM / RATIO>>&
expected_idx)
{
static_assert(RATIO == 2, "Extend functionality if other ratio is used.");
ASSERT_EQ(static_cast<int>(input.size()), NUM);
ASSERT_EQ(static_cast<int>(expected_output.size()), NUM / RATIO);
constexpr int CompressedSize = NUM / RATIO;
constexpr int IdxNumWords = sparse::detail::idx_words_needed<CompressedSize>;
using IdxType = sparse::detail::SparseIdxPack<IdxNumWords>;
int devCount;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
@@ -265,24 +213,31 @@ void sparse_transform_verify(const std::vector<Type>& input,
}
float* d_v;
int32_t* d_idx;
void* d_idx;
static constexpr auto Size = sizeof(Type) * NUM;
HIP_CHECK_ERROR(hipMalloc(&d_v, Size));
HIP_CHECK_ERROR(hipMalloc(&d_idx, sizeof(int32_t)));
HIP_CHECK_ERROR(hipMalloc(&d_idx, sizeof(IdxType)));
// Copy inputs to device
HIP_CHECK_ERROR(hipMemcpy(d_v, input.data(), Size, hipMemcpyHostToDevice));
test_sparse_transform<RATIO, ext_vector_t<Type, NUM>><<<1, 32>>>(d_v, d_idx);
using Kernel = SparseTransformKernel<RATIO, ext_vector_t<Type, NUM>>;
ck_tile::launch_kernel(ck_tile::stream_config{},
ck_tile::make_kernel(Kernel{}, dim3(1), dim3(32), 0, d_v, d_idx));
HIP_CHECK_ERROR(hipDeviceSynchronize());
std::vector<Type> h_out(NUM / RATIO, static_cast<Type>(0));
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_v, Size / RATIO, hipMemcpyDeviceToHost));
int32_t h_idx;
HIP_CHECK_ERROR(hipMemcpy(&h_idx, d_idx, sizeof(int32_t), hipMemcpyDeviceToHost));
IdxType h_idx{};
HIP_CHECK_ERROR(hipMemcpy(&h_idx, d_idx, sizeof(IdxType), hipMemcpyDeviceToHost));
EXPECT_EQ(h_idx, expected_idx) << "Index mask mismatch";
EXPECT_EQ(h_idx.words[0], expected_idx.words[0]) << "Index mask mismatch (word 0)";
for(int w = 1; w < IdxNumWords; ++w)
{
EXPECT_EQ(h_idx.words[w], expected_idx.words[w])
<< "Index mask mismatch (word " << w << ")";
}
for(int i = 0; i < NUM / RATIO; ++i)
{
EXPECT_EQ(h_out[i], expected_output[i]) << "Output mismatch at position " << i;
@@ -296,10 +251,11 @@ void sparse_transform_verify(const std::vector<Type>& input,
// initialization values (from nonzero_elems init) that don't correspond to the
// default index (slot 2). We only validate entries where the index was explicitly
// set, i.e. where input[slot] is non-zero.
constexpr int CompressedSize = NUM / RATIO;
for(int i = 0; i < CompressedSize; ++i)
{
int slot = (h_idx >> (2 * i)) & 0b11;
const int word = (2 * i) / 32;
const int shift = (2 * i) % 32;
int slot = (h_idx.words[word] >> shift) & 0b11;
int group = i / 2;
Type input_at_slot = input[group * 4 + slot];
// Only check when input at the indexed slot is non-zero (explicitly assigned)
@@ -319,20 +275,36 @@ void sparse_transform_verify(const std::vector<Type>& input,
// Helper: build expected index from a per-group 4-bit pattern, repeated for all groups.
// Each group of 4 input elements contributes 2 compressed elements -> 2 x 2-bit index fields = 4
// bits.
static int32_t build_repeated_group_idx(int num_groups, int32_t group_bits_4)
template <int NumGroups>
static auto build_repeated_group_idx(int32_t group_bits_4)
{
int32_t idx = 0;
for(int g = 0; g < num_groups; ++g)
idx |= (group_bits_4 << (4 * g));
constexpr int CompressedSize = NumGroups * 2;
constexpr int NumWords = sparse::detail::idx_words_needed<CompressedSize>;
sparse::detail::SparseIdxPack<NumWords> idx{};
for(int g = 0; g < NumGroups; ++g)
{
const int bit_pos = g * 4;
const int word = bit_pos / 32;
const int shift = bit_pos % 32;
idx.words[word] |= (group_bits_4 << shift);
}
return idx;
}
// Helper: build expected index from alternating even/odd 4-bit group patterns.
static int32_t build_alternating_group_idx(int num_groups, int32_t even_bits_4, int32_t odd_bits_4)
template <int NumGroups>
static auto build_alternating_group_idx(int32_t even_bits_4, int32_t odd_bits_4)
{
int32_t idx = 0;
for(int g = 0; g < num_groups; ++g)
idx |= ((g % 2 == 0 ? even_bits_4 : odd_bits_4) << (4 * g));
constexpr int CompressedSize = NumGroups * 2;
constexpr int NumWords = sparse::detail::idx_words_needed<CompressedSize>;
sparse::detail::SparseIdxPack<NumWords> idx{};
for(int g = 0; g < NumGroups; ++g)
{
const int bit_pos = g * 4;
const int word = bit_pos / 32;
const int shift = bit_pos % 32;
idx.words[word] |= ((g % 2 == 0 ? even_bits_4 : odd_bits_4) << shift);
}
return idx;
}
@@ -354,7 +326,7 @@ void sparse_transform_test_case()
expected_out[i] = v[i * 2];
}
int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1000);
auto expected_idx = build_repeated_group_idx<NUM / 4>(0b1000);
sparse_transform_verify<NUM, RATIO, Type>(v, expected_out, expected_idx);
}
@@ -365,6 +337,7 @@ TEST(SparseTransformsTest, ValidCompressionRatio)
sparse_transform_test_case<8, 2, fp16_t>();
sparse_transform_test_case<16, 2, fp16_t>();
sparse_transform_test_case<32, 2, fp16_t>();
sparse_transform_test_case<64, 2, fp16_t>(); // multi-word SparseIdxPack
}
// All-zero input: no non-zeros in any group of 4.
@@ -377,7 +350,7 @@ void sparse_transform_all_zero()
using T = fp16_t;
std::vector<T> input(NUM, static_cast<T>(0));
std::vector<T> expected_output(NUM / 2, static_cast<T>(0));
int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1010);
auto expected_idx = build_repeated_group_idx<NUM / 4>(0b1010);
sparse_transform_verify<NUM, 2, T>(input, expected_output, expected_idx);
}
@@ -386,6 +359,7 @@ TEST(SparseTransformsTest, AllZeroInput)
sparse_transform_all_zero<8>();
sparse_transform_all_zero<16>();
sparse_transform_all_zero<32>();
sparse_transform_all_zero<64>(); // multi-word SparseIdxPack
}
// Single non-zero per group of 4 (at slot 3).
@@ -408,7 +382,7 @@ void sparse_transform_single_nonzero()
expected_output[g * 2 + 1] = val;
}
int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1011);
auto expected_idx = build_repeated_group_idx<NUM / 4>(0b1011);
sparse_transform_verify<NUM, 2, T>(input, expected_output, expected_idx);
}
@@ -417,6 +391,7 @@ TEST(SparseTransformsTest, SingleNonZeroPerGroup)
sparse_transform_single_nonzero<8>();
sparse_transform_single_nonzero<16>();
sparse_transform_single_nonzero<32>();
sparse_transform_single_nonzero<64>(); // multi-word SparseIdxPack
}
// Non-zeros at slots 1 and 3 in each group.
@@ -439,7 +414,7 @@ void sparse_transform_slots_1_and_3()
expected_output[g * 2 + 1] = b;
}
int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1101);
auto expected_idx = build_repeated_group_idx<NUM / 4>(0b1101);
sparse_transform_verify<NUM, 2, T>(input, expected_output, expected_idx);
}
@@ -448,6 +423,7 @@ TEST(SparseTransformsTest, NonZerosAtSlots1And3)
sparse_transform_slots_1_and_3<8>();
sparse_transform_slots_1_and_3<16>();
sparse_transform_slots_1_and_3<32>();
sparse_transform_slots_1_and_3<64>(); // multi-word SparseIdxPack
}
// Non-zeros at slots 0 and 3 in each group (non-adjacent).
@@ -470,7 +446,7 @@ void sparse_transform_slots_0_and_3()
expected_output[g * 2 + 1] = b;
}
int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1100);
auto expected_idx = build_repeated_group_idx<NUM / 4>(0b1100);
sparse_transform_verify<NUM, 2, T>(input, expected_output, expected_idx);
}
@@ -479,6 +455,7 @@ TEST(SparseTransformsTest, NonZerosAtSlots0And3)
sparse_transform_slots_0_and_3<8>();
sparse_transform_slots_0_and_3<16>();
sparse_transform_slots_0_and_3<32>();
sparse_transform_slots_0_and_3<64>(); // multi-word SparseIdxPack
}
// Mixed sparsity pattern: even groups have non-zeros at slots 0,2; odd groups at slots 1,3.
@@ -511,7 +488,7 @@ void sparse_transform_mixed()
expected_output[g * 2 + 1] = b;
}
int32_t expected_idx = build_alternating_group_idx(NUM / 4, 0b1000, 0b1101);
auto expected_idx = build_alternating_group_idx<NUM / 4>(0b1000, 0b1101);
sparse_transform_verify<NUM, 2, T>(input, expected_output, expected_idx);
}
@@ -520,4 +497,156 @@ TEST(SparseTransformsTest, MixedSparsityPattern)
sparse_transform_mixed<8>();
sparse_transform_mixed<16>();
sparse_transform_mixed<32>();
sparse_transform_mixed<64>(); // multi-word SparseIdxPack
}
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy>
struct SparsePipelineKernel
{
static constexpr int kBlockSize = mma_pipeline_test::getCMakeWaveSize();
__device__ void
operator()(const void* a_per_lane, const void* b_per_lane, void* c_per_lane) const
{
using CompilerTarget = decltype(get_compiler_target());
using Pipeline = SparseMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
AccumPolicy,
CompilerTarget>;
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
const uint32_t lane = threadIdx.x;
AVecType a;
BVecType b;
CVecType c;
__builtin_memcpy(&a,
static_cast<const uint8_t*>(a_per_lane) + lane * sizeof(AVecType),
sizeof(AVecType));
__builtin_memcpy(&b,
static_cast<const uint8_t*>(b_per_lane) + lane * sizeof(BVecType),
sizeof(BVecType));
__builtin_memset(&c, 0, sizeof(CVecType));
if constexpr(MmaOpTraits<typename Pipeline::MmaOp>::IsSupported)
{
Pipeline::exec(a, b, c);
__builtin_memcpy(
static_cast<uint8_t*>(c_per_lane) + lane * sizeof(CVecType), &c, sizeof(CVecType));
}
}
};
namespace {
const auto should_skip = [](amdgcn_target_id currentArchId) {
bool isSupportedWmma = (currentArchId >= amdgcn_target_id::GFX1200) &&
(currentArchId <= amdgcn_target_id::GFX12_GENERIC);
bool isSupportedMfma =
(currentArchId >= amdgcn_target_id::GFX942) && (currentArchId <= amdgcn_target_id::GFX950);
return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma));
};
} // namespace
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy>
struct SparsePipelineFactory
{
template <typename Target>
struct Create
{
using type = SparseMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
AccumPolicy,
Target>;
};
};
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR>
void SparsePipeline_Real_impl()
{
using Factory =
SparsePipelineFactory<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK, AccumPolicy>;
using Kernel =
SparsePipelineKernel<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK, AccumPolicy>;
mma_pipeline_test::
run_pipeline_matrix_test<Factory::template Create, Kernel, AType, BType, CType>(
WaveTileM, WaveTileN, WaveTileK, should_skip, Kernel{}, /*isSparse=*/true);
}
// Full matrix verification: 16x16x32 single-fragment sparse pipeline (ROW_MAJOR)
TEST(SparseMmaPipeline, FullMatrixVerify_16x16x32)
{
SparsePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u>();
}
// Multi-fragment K: 16x16x64 -> 2 K fragments, tests internal K iteration (ROW_MAJOR)
TEST(SparseMmaPipeline, FullMatrixVerify_16x16x64)
{
SparsePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 64u>();
}
// Full matrix verification: 16x16x32 single-fragment sparse pipeline (COL_MAJOR)
TEST(SparseMmaPipeline, FullMatrixVerify_16x16x32_ColMajor)
{
SparsePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, MmaAccumPolicy::COL_MAJOR>();
}
// Multi-fragment K: 16x16x64 -> 2 K fragments, tests internal K iteration (COL_MAJOR)
TEST(SparseMmaPipeline, FullMatrixVerify_16x16x64_ColMajor)
{
SparsePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 64u, MmaAccumPolicy::COL_MAJOR>();
}
// Multi-fragment K: 16x16x128 -> 4 K fragments, exercises multi-word SparseIdxPack (ROW_MAJOR)
TEST(SparseMmaPipeline, FullMatrixVerify_16x16x128)
{
SparsePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 128u>();
}
// Multi-fragment K: 16x16x256 -> 8 K fragments, exercises larger multi-word SparseIdxPack
// (ROW_MAJOR)
TEST(SparseMmaPipeline, FullMatrixVerify_16x16x256)
{
SparsePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 256u>();
}
// Multi-fragment K: 16x16x128 -> 4 K fragments (COL_MAJOR)
TEST(SparseMmaPipeline, FullMatrixVerify_16x16x128_ColMajor)
{
SparsePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 128u, MmaAccumPolicy::COL_MAJOR>();
}
// Multi-fragment K: 16x16x256 -> 8 K fragments (COL_MAJOR)
TEST(SparseMmaPipeline, FullMatrixVerify_16x16x256_ColMajor)
{
SparsePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 256u, MmaAccumPolicy::COL_MAJOR>();
}

View File

@@ -3,52 +3,70 @@
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
#include "ck_tile/core/arch/mma/mma.hpp"
#include "pipeline_tests_helper.hpp"
#include <memory>
using namespace ck_tile;
using namespace ck_tile::core::arch;
using namespace ck_tile::core::arch::mma;
// Kernel functor: constructs Pipeline internally using device-side get_compiler_target().
// Uses void* for data to avoid host/device symbol mismatches.
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
bool CTranspose>
__global__ void test_wavewise_pipeline(void* a, void* b, void* c, void* out)
MmaAccumPolicy AccumPolicy,
bool TransposeC>
struct WaveWisePipelineKernel
{
using CompilerTarget = decltype(get_compiler_target());
static constexpr int kBlockSize = mma_pipeline_test::getCMakeWaveSize();
using Pipeline = WaveWiseMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
MmaOpFamily::DENSE,
MmaAccumPolicy::ROW_MAJOR,
CTranspose,
CompilerTarget>;
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
auto result = Pipeline::exec(*reinterpret_cast<AVecType*>(a),
*reinterpret_cast<BVecType*>(b),
*reinterpret_cast<CVecType*>(c));
if constexpr(MmaOpTraits<typename Pipeline::MmaOp>::IsSupported)
__device__ void
operator()(const void* a_per_lane, const void* b_per_lane, void* c_per_lane) const
{
// When the MmaOp is Unsupported (default) it returns the CVecType by value
// so this cast is impossible...
__builtin_memcpy(out, static_cast<const void*>(result), sizeof(CVecType));
using CompilerTarget = decltype(get_compiler_target());
using Pipeline = WaveWiseMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
MmaOpFamily::DENSE,
AccumPolicy,
TransposeC,
CompilerTarget>;
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
const uint32_t lane = threadIdx.x;
AVecType a;
BVecType b;
CVecType c;
__builtin_memcpy(&a,
static_cast<const uint8_t*>(a_per_lane) + lane * sizeof(AVecType),
sizeof(AVecType));
__builtin_memcpy(&b,
static_cast<const uint8_t*>(b_per_lane) + lane * sizeof(BVecType),
sizeof(BVecType));
__builtin_memset(&c, 0, sizeof(CVecType));
if constexpr(MmaOpTraits<typename Pipeline::MmaOp>::IsSupported)
{
Pipeline::exec(a, b, c);
__builtin_memcpy(
static_cast<uint8_t*>(c_per_lane) + lane * sizeof(CVecType), &c, sizeof(CVecType));
}
}
}
};
namespace {
const auto should_skip = [](amdgcn_target_id currentArchId) {
@@ -57,37 +75,95 @@ const auto should_skip = [](amdgcn_target_id currentArchId) {
(currentArchId >= amdgcn_target_id::GFX942) && (currentArchId <= amdgcn_target_id::GFX950);
return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma));
};
const std::function<fp32_t(uint32_t)> validator = [](uint32_t waveTileK) {
return static_cast<fp32_t>(waveTileK);
};
} // namespace
TEST(WaveWiseMmaPipeline, testKIter)
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy>
struct WaveWisePipelineFactory
{
MmaPipelineTest<> test;
const auto kernel = [](uint32_t waveSize, void* a, void* b, void* c, void* out) {
test_wavewise_pipeline<MmaPipelineTest<>::AType,
MmaPipelineTest<>::BType,
MmaPipelineTest<>::CType,
MmaPipelineTest<>::WaveTileM,
MmaPipelineTest<>::WaveTileN,
MmaPipelineTest<>::WaveTileK,
false><<<1, waveSize>>>(a, b, c, out);
template <typename Target>
struct Create
{
using type = WaveWiseMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
MmaOpFamily::DENSE,
AccumPolicy,
false,
Target>;
};
test.test_pipeline(should_skip, kernel, validator);
};
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
bool TransposeC = false>
void WaveWisePipeline_Real_impl()
{
using Factory =
WaveWisePipelineFactory<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK, AccumPolicy>;
using Kernel = WaveWisePipelineKernel<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
AccumPolicy,
TransposeC>;
mma_pipeline_test::
run_pipeline_matrix_test<Factory::template Create, Kernel, AType, BType, CType>(
WaveTileM,
WaveTileN,
WaveTileK,
should_skip,
Kernel{},
/*isSparse=*/false,
/*transposeExpected=*/TransposeC);
}
TEST(WaveWiseMmaPipeline, testKIterSwapAB)
TEST(WaveWiseMmaPipeline, FullMatrixVerify_16x16x32)
{
MmaPipelineTest<> test;
const auto kernel = [](uint32_t waveSize, void* a, void* b, void* c, void* out) {
test_wavewise_pipeline<MmaPipelineTest<>::AType,
MmaPipelineTest<>::BType,
MmaPipelineTest<>::CType,
MmaPipelineTest<>::WaveTileM,
MmaPipelineTest<>::WaveTileN,
MmaPipelineTest<>::WaveTileK,
true><<<1, waveSize>>>(a, b, c, out);
};
test.test_pipeline(should_skip, kernel, validator);
WaveWisePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u>();
}
TEST(WaveWiseMmaPipeline, FullMatrixVerify_16x16x32_SwapAB)
{
WaveWisePipeline_Real_impl<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
MmaAccumPolicy::ROW_MAJOR,
true>();
}
TEST(WaveWiseMmaPipeline, FullMatrixVerify_16x16x32_ColMajor)
{
WaveWisePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, MmaAccumPolicy::COL_MAJOR>();
}
TEST(WaveWiseMmaPipeline, FullMatrixVerify_16x16x32_ColMajor_TransposeC)
{
WaveWisePipeline_Real_impl<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
MmaAccumPolicy::COL_MAJOR,
true>();
}

View File

@@ -1,8 +1,6 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "get_wave_size_helper.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
@@ -12,6 +10,8 @@
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "get_cmake_targets_helper.hpp"
#include <gtest/gtest.h>
#include <hip/hip_runtime.h>
@@ -21,6 +21,7 @@
using namespace ck_tile;
using namespace ck_tile::core::arch;
using namespace ck_tile::core::arch::mma;
using namespace ck_tile::core::arch::testing;
// Dummy values for testing
constexpr uint32_t DummyTargetIdVal = 55555u;
@@ -484,7 +485,7 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
const auto wave_size = getDeviceWaveSize();
const auto wave_size = getCMakeWaveSize();
test_accum_over_k<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
HIP_CHECK_ERROR(hipDeviceSynchronize());
@@ -585,7 +586,7 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
const auto wave_size = getDeviceWaveSize();
const auto wave_size = getCMakeWaveSize();
test_accum_over_k<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
HIP_CHECK_ERROR(hipDeviceSynchronize());