mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4837 (commit 6316035)
[CK TILE] Unification of sparse MFMA/WMMA policy structs (#4837) ## Motivation The existing unification work supports DENSE intrinsics. In this PR we enable support for SPARSE as well as SCALE intrinsics and add an example SPARSE implementation. ## Technical Details Mostly trivial changes. One framework change is that the desired `MmaOpFamily` is passed to the `MmaDefaultSelector`. As my relevant commit explains, we do not support a fallback family at the moment, but it is something we can consider. ## Test Plan Added a new test for the relevant sparse specializations. ## Test Result Test should pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
6e558658ea
commit
03ce21ddcb
@@ -7,6 +7,10 @@ if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx12")
|
||||
add_gtest_executable(test_amdgcn_sparse_mma test_amdgcn_sparse_mma.cpp)
|
||||
target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp)
|
||||
target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
34
test/ck_tile/core/arch/mma/get_wave_size_helper.hpp
Normal file
34
test/ck_tile/core/arch/mma/get_wave_size_helper.hpp
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
__global__ void getWaveSizeForSelectedOp(uint32_t* waveSize)
|
||||
{
|
||||
using CompilerTarget = decltype(ck_tile::core::arch::get_compiler_target());
|
||||
|
||||
if(waveSize)
|
||||
*waveSize = static_cast<uint32_t>(CompilerTarget::WAVE_SIZE_ID);
|
||||
}
|
||||
|
||||
static __host__ uint32_t getDeviceWaveSize()
|
||||
{
|
||||
uint32_t* d_wave_size;
|
||||
HIP_CHECK_ERROR(hipMalloc(&d_wave_size, sizeof(uint32_t)));
|
||||
getWaveSizeForSelectedOp<<<1, 64>>>(d_wave_size);
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
uint32_t wave_size;
|
||||
HIP_CHECK_ERROR(hipMemcpy(&wave_size, d_wave_size, sizeof(uint32_t), hipMemcpyDeviceToHost));
|
||||
return wave_size;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -11,6 +11,8 @@
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
|
||||
#include "get_wave_size_helper.hpp"
|
||||
|
||||
using namespace ck_tile;
|
||||
using namespace ck_tile::core::arch;
|
||||
using namespace ck_tile::core::arch::mma;
|
||||
@@ -47,10 +49,12 @@ struct amdgcn_mma<fp32_t,
|
||||
16u,
|
||||
DummyCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
enable_if_target_id_dummy_t<CompilerTarget>>
|
||||
{
|
||||
// Mfma operation type
|
||||
using OpType = DummyOpType;
|
||||
using OpType = DummyOpType;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Register types
|
||||
using AVecType = ext_vector_t<fp32_t, 4>;
|
||||
@@ -81,8 +85,15 @@ struct amdgcn_mma<fp32_t,
|
||||
// Have an alias so we can test supported arch vs unsupported arch
|
||||
// TODO: c++20 template <amdgcn_target_arch_id CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
using DummyAmdgcnMma =
|
||||
amdgcn_mma<fp32_t, fp32_t, fp32_t, 16u, 16u, 16u, DummyCtrlFlags, CompilerTarget>;
|
||||
using DummyAmdgcnMma = amdgcn_mma<fp32_t,
|
||||
fp32_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
DummyCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
|
||||
/*! @struct MmaDefaultSelector
|
||||
* @brief For dummy Id only, instantiate tests for both MFMA and WMMA selectors so we can them both
|
||||
@@ -93,6 +104,7 @@ using DummyAmdgcnMma =
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -100,7 +112,8 @@ template <typename ADataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget>
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: requires
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
@@ -110,7 +123,9 @@ struct MmaDefaultSelector<ADataType,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
enable_if_target_id_dummy_t<CompilerTarget>>
|
||||
OpFamily,
|
||||
enable_if_all<enable_if_target_id_dummy_t<CompilerTarget>,
|
||||
std::enable_if_t<OpFamily == MmaOpFamily::DENSE>>>
|
||||
{
|
||||
using SelectedOp = DummyAmdgcnMma<CompilerTarget>;
|
||||
};
|
||||
@@ -128,6 +143,8 @@ TEST(TestAmdgcnMma, ArchSupported)
|
||||
// Check OpType
|
||||
EXPECT_TRUE(
|
||||
(std::is_same<typename MmaOp::OpType, DummyOpType>::value)); // OpType is DummyOpType
|
||||
// Check OpFamily
|
||||
EXPECT_TRUE((is_mma_op_of_family_v<MmaOpFamily::DENSE, MmaOp>));
|
||||
|
||||
// Check AVecType, BVecType, CVecType
|
||||
EXPECT_TRUE((std::is_same<typename MmaOp::AVecType, ext_vector_t<fp32_t, 4>>::value));
|
||||
@@ -157,6 +174,8 @@ TEST(TestAmdgcnMma, ArchUnsupported)
|
||||
|
||||
// OpType should be Unsupported
|
||||
EXPECT_TRUE((std::is_same<typename MmaOp::OpType, Unsupported>::value));
|
||||
// OpFamily should be Undefined
|
||||
EXPECT_TRUE((is_mma_op_of_family_v<MmaOpFamily::UNDEFINED, MmaOp>));
|
||||
|
||||
// AVecType, BVecType, CVecType should match default
|
||||
EXPECT_TRUE((std::is_same<typename MmaOp::AVecType, ext_vector_t<fp32_t, 1>>::value));
|
||||
@@ -367,6 +386,7 @@ TEST(TestAmdgcnMma, MmaOpTraitsUnsupportedMembers)
|
||||
EXPECT_TRUE((std::is_same<typename Traits::AVecType, ext_vector_t<fp32_t, 1>>::value));
|
||||
EXPECT_TRUE((std::is_same<typename Traits::BVecType, ext_vector_t<fp32_t, 1>>::value));
|
||||
EXPECT_TRUE((std::is_same<typename Traits::CVecType, ext_vector_t<fp32_t, 1>>::value));
|
||||
EXPECT_EQ(Traits::OpFamily, MmaOpFamily::UNDEFINED);
|
||||
EXPECT_EQ(Traits::kAMBlock, 0);
|
||||
EXPECT_EQ(Traits::kBNBlock, 0);
|
||||
EXPECT_EQ(Traits::kAMLane, 0);
|
||||
@@ -386,9 +406,14 @@ TEST(TestAmdgcnMma, MmaOpTraitsUnsupportedMembers)
|
||||
TEST(TestAmdgcnMma, MmaDefaultSelectorSupported)
|
||||
{
|
||||
// Direct selection of the supported dummy instruction
|
||||
using SelectedMma =
|
||||
typename MmaDefaultSelector<fp32_t, fp32_t, fp32_t, 16u, 16u, 16u, DummyCompilerTarget>::
|
||||
SelectedOp;
|
||||
using SelectedMma = typename MmaDefaultSelector<fp32_t,
|
||||
fp32_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
DummyCompilerTarget,
|
||||
MmaOpFamily::DENSE>::SelectedOp;
|
||||
// Should select DummyAmdgcnMma specialization
|
||||
EXPECT_TRUE((std::is_same<SelectedMma, DummyAmdgcnMma<DummyCompilerTarget>>::value));
|
||||
// OpType should be DummyOpType
|
||||
@@ -401,8 +426,14 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorSupported)
|
||||
TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupported)
|
||||
{
|
||||
// Direct selection of the unsupported dummy instruction
|
||||
using SelectedMma =
|
||||
MmaDefaultSelector<fp32_t, fp32_t, fp32_t, 16u, 16u, 16u, amdgcn_target<>>::SelectedOp;
|
||||
using SelectedMma = MmaDefaultSelector<fp32_t,
|
||||
fp32_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>::SelectedOp;
|
||||
// OpType should be Unsupported
|
||||
EXPECT_TRUE((std::is_same<typename SelectedMma::OpType, Unsupported>::value));
|
||||
// IsSupported should be false
|
||||
@@ -414,9 +445,14 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupported)
|
||||
TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedFragment)
|
||||
{
|
||||
// Select indirectly with a fragment size of 256x128x64
|
||||
using SelectedMma =
|
||||
MmaDefaultSelector<fp32_t, fp32_t, fp32_t, 256u, 128u, 64u, DummyCompilerTarget>::
|
||||
SelectedOp;
|
||||
using SelectedMma = MmaDefaultSelector<fp32_t,
|
||||
fp32_t,
|
||||
fp32_t,
|
||||
256u,
|
||||
128u,
|
||||
64u,
|
||||
DummyCompilerTarget,
|
||||
MmaOpFamily::DENSE>::SelectedOp;
|
||||
// Should select DummyAmdgcnMma specialization
|
||||
EXPECT_TRUE((std::is_same<SelectedMma, DummyAmdgcnMma<DummyCompilerTarget>>::value));
|
||||
// OpType should be DummyOpType
|
||||
@@ -429,8 +465,14 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedFragment)
|
||||
TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedFragment)
|
||||
{
|
||||
// This should fall back to unsupported since DummyAmdgcnMma only supports 16x16x16
|
||||
using SelectedMma =
|
||||
MmaDefaultSelector<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, DummyCompilerTarget>::SelectedOp;
|
||||
using SelectedMma = MmaDefaultSelector<fp32_t,
|
||||
fp32_t,
|
||||
fp32_t,
|
||||
8u,
|
||||
8u,
|
||||
8u,
|
||||
DummyCompilerTarget,
|
||||
MmaOpFamily::DENSE>::SelectedOp;
|
||||
EXPECT_FALSE((std::is_same<typename SelectedMma::OpType, Unsupported>::value));
|
||||
EXPECT_TRUE(MmaOpTraits<SelectedMma>::IsSupported);
|
||||
}
|
||||
@@ -438,8 +480,14 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedFragment)
|
||||
// Test MmaDefaultSelector for a different data type (fp16_t) and unsupported arch
|
||||
TEST(TestAmdgcnMma, MmaDefaultSelectorFp16Unsupported)
|
||||
{
|
||||
using SelectedMma =
|
||||
MmaDefaultSelector<fp16_t, fp16_t, fp16_t, 16u, 16u, 16u, amdgcn_target<>>::SelectedOp;
|
||||
using SelectedMma = MmaDefaultSelector<fp16_t,
|
||||
fp16_t,
|
||||
fp16_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>::SelectedOp;
|
||||
// Should select default amdgcn_mma (Unsupported)
|
||||
EXPECT_TRUE((std::is_same<typename SelectedMma::OpType, Unsupported>::value));
|
||||
EXPECT_FALSE(MmaOpTraits<SelectedMma>::IsSupported);
|
||||
@@ -464,7 +512,8 @@ __global__ void test_accum_over_k(void* a, void* b, void* c, void* out)
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
decltype(get_compiler_target())>;
|
||||
decltype(get_compiler_target()),
|
||||
MmaOpFamily::DENSE>;
|
||||
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
using MmaTraits = MmaOpTraits<MmaOp>;
|
||||
@@ -561,8 +610,9 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
|
||||
|
||||
// Need at least 1 WG with 64 threads to get defined MFMA/WMMA behaviour
|
||||
test_accum_over_k<AType, BType, CType, FragM, FragN, FragK><<<1, 64>>>(d_a, d_b, d_c, d_out);
|
||||
const auto wave_size = getDeviceWaveSize();
|
||||
test_accum_over_k<AType, BType, CType, FragM, FragN, FragK>
|
||||
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
|
||||
@@ -661,8 +711,9 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
|
||||
|
||||
// Need at least 1 WG with 64 threads to get defined MFMA/WMMA behaviour
|
||||
test_accum_over_k<AType, BType, CType, FragM, FragN, FragK><<<1, 64>>>(d_a, d_b, d_c, d_out);
|
||||
const auto wave_size = getDeviceWaveSize();
|
||||
test_accum_over_k<AType, BType, CType, FragM, FragN, FragK>
|
||||
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
|
||||
|
||||
274
test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp
Normal file
274
test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp
Normal file
@@ -0,0 +1,274 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
#include "get_wave_size_helper.hpp"
|
||||
|
||||
using namespace ck_tile;
|
||||
using namespace ck_tile::core::arch;
|
||||
using namespace ck_tile::core::arch::mma;
|
||||
|
||||
using CompilerTargetGfx950 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX950>());
|
||||
|
||||
TEST(SparseMMATrait, SparseMfmaGfx950Specialization)
|
||||
{
|
||||
// Test fp16 → fp32 sparse MFMA for GFX950 (16x16x32)
|
||||
using TestSparseMfma16x16 = amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
DefaultSparseMfmaCtrlFlags,
|
||||
CompilerTargetGfx950,
|
||||
MmaOpFamily::SPARSE>;
|
||||
|
||||
static_assert(std::is_same_v<typename TestSparseMfma16x16::OpType, MfmaOp> &&
|
||||
TestSparseMfma16x16::OpFamily == MmaOpFamily::SPARSE,
|
||||
"GFX950 sparse 16x16x32 should have SparseMFMAOp type");
|
||||
|
||||
static_assert(is_mma_op_of_family_v<MmaOpFamily::SPARSE, TestSparseMfma16x16>,
|
||||
"GFX950 sparse 16x16x32 should be detected as Sparse");
|
||||
|
||||
std::cout << "GFX950 sparse MFMA specialization is correct" << std::endl;
|
||||
}
|
||||
|
||||
TEST(SparseMMATrait, MmaOpTraitsIntegration)
|
||||
{
|
||||
// Create a sparse MMA op (16x16x32 fp16 specialization)
|
||||
using TestSparseMmma = amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
DefaultSparseMfmaCtrlFlags,
|
||||
CompilerTargetGfx950,
|
||||
MmaOpFamily::SPARSE>;
|
||||
|
||||
// Get its traits
|
||||
using TestTraits = MmaOpTraits<TestSparseMmma>;
|
||||
|
||||
// Verify trait detection
|
||||
static_assert(TestTraits::IsSparse, "Sparse MMA should be detected as sparse");
|
||||
static_assert(TestTraits::IsSupported, "Sparse MMA specialization should be supported");
|
||||
static_assert(TestTraits::IsMfma, "Sparse MFMA should be detected as MFMA");
|
||||
static_assert(!TestTraits::IsWmma, "Sparse MFMA should not be detected as WMMA");
|
||||
|
||||
std::cout << "MmaOpTraits correctly integrates sparse operations" << std::endl;
|
||||
}
|
||||
|
||||
TEST(SparseMMATrait, DenseVsSparseDistinction)
|
||||
{
|
||||
// Dense MFMA from mfma/mfma_gfx9.hpp
|
||||
using DenseMfma = amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
DefaultMfmaCtrlFlags,
|
||||
CompilerTargetGfx950,
|
||||
MmaOpFamily::DENSE>;
|
||||
|
||||
// Sparse MFMA on GFX950
|
||||
using SparseMfma = amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
DefaultSparseMfmaCtrlFlags,
|
||||
CompilerTargetGfx950,
|
||||
MmaOpFamily::SPARSE>;
|
||||
|
||||
// Verify they have different operation types
|
||||
static_assert(std::is_same_v<typename DenseMfma::OpType, typename SparseMfma::OpType> &&
|
||||
DenseMfma::OpFamily != SparseMfma::OpFamily,
|
||||
"Dense and Sparse MFMA should have the same OpType tags and different OpFamily");
|
||||
|
||||
// Verify traits correctly identify them
|
||||
static_assert(MmaOpTraits<DenseMfma>::IsMfma && MmaOpTraits<DenseMfma>::IsDense &&
|
||||
!MmaOpTraits<DenseMfma>::IsSparse && !MmaOpTraits<DenseMfma>::IsScale &&
|
||||
MmaOpTraits<DenseMfma>::IsSupported,
|
||||
"Dense MFMA should be identified correctly");
|
||||
|
||||
static_assert(MmaOpTraits<SparseMfma>::IsSparse && MmaOpTraits<SparseMfma>::IsMfma &&
|
||||
!MmaOpTraits<SparseMfma>::IsDense && !MmaOpTraits<SparseMfma>::IsScale &&
|
||||
MmaOpTraits<SparseMfma>::IsSupported,
|
||||
"Sparse MFMA should be identified correctly");
|
||||
|
||||
std::cout << "Dense and sparse MMA operations are correctly distinguished" << std::endl;
|
||||
}
|
||||
|
||||
TEST(SparseMMATrait, SparseSelector)
|
||||
{
|
||||
static_for<1, 33, 1>{}([](auto i) {
|
||||
using Selected = typename MmaDefaultSelector<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
static_cast<uint32_t>(i),
|
||||
static_cast<uint32_t>(i),
|
||||
static_cast<uint32_t>(2 * i),
|
||||
CompilerTargetGfx950,
|
||||
MmaOpFamily::SPARSE>::SelectedOp;
|
||||
|
||||
static constexpr bool isValid = (i == 16) || (i == 32);
|
||||
if constexpr(isValid)
|
||||
{
|
||||
// Selector should pick a sparse MFMA implementation
|
||||
static_assert(MmaOpTraits<Selected>::IsSparse);
|
||||
static_assert(MmaOpTraits<Selected>::IsMfma);
|
||||
static_assert(MmaOpTraits<Selected>::IsSupported);
|
||||
static_assert((std::is_same<typename Selected::OpType, MfmaOp>::value));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Selector should pick the unsupported pass through
|
||||
static_assert(!MmaOpTraits<Selected>::IsSupported);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK>
|
||||
__global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out)
|
||||
{
|
||||
using CompilerTarget = decltype(get_compiler_target());
|
||||
using Selector = MmaDefaultSelector<AType,
|
||||
BType,
|
||||
CType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE>;
|
||||
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
using MmaTraits = MmaOpTraits<MmaOp>;
|
||||
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
|
||||
static constexpr uint32_t kIters = FragK / MmaTraits::BlockK;
|
||||
|
||||
// Initialize the accumulator
|
||||
CVecType result = *reinterpret_cast<typename MmaOp::CVecType*>(c);
|
||||
|
||||
// Accumulate input AxB over FragK/BlockK iterations
|
||||
for(uint32_t i = 0; i < kIters; ++i)
|
||||
{
|
||||
result = MmaOp::exec(*reinterpret_cast<typename MmaOp::AVecType*>(a),
|
||||
*reinterpret_cast<typename MmaOp::BVecType*>(b),
|
||||
result);
|
||||
}
|
||||
|
||||
*reinterpret_cast<typename MmaOp::CVecType*>(out) = result;
|
||||
}
|
||||
|
||||
// Live test on real hardware for sparse selection and execution.
|
||||
TEST(SparseMMATrait, MmaSelector_Sparse_F16_F16_F32_16x16x32_Real)
|
||||
{
|
||||
int devCount;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceCount(&devCount));
|
||||
|
||||
hipDeviceProp_t devProp;
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev));
|
||||
|
||||
auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName);
|
||||
bool hasDevice = static_cast<bool>(devCount > 0);
|
||||
int deviceWarpSize = devProp.warpSize;
|
||||
|
||||
bool isSupportedWmma = (currentArchId >= amdgcn_target_id::GFX1200) &&
|
||||
(currentArchId <= amdgcn_target_id::GFX12_GENERIC);
|
||||
bool isSupportedMfma =
|
||||
(currentArchId >= amdgcn_target_id::GFX942) && (currentArchId <= amdgcn_target_id::GFX950);
|
||||
// TODO: c++20 add check for arch id
|
||||
if(!hasDevice || (currentArchId == amdgcn_target_id::HOST) ||
|
||||
!(isSupportedWmma || isSupportedMfma))
|
||||
{
|
||||
GTEST_SKIP() << "No HIP device found. Skipping test.";
|
||||
}
|
||||
|
||||
using AType = fp16_t;
|
||||
using BType = fp16_t;
|
||||
using CType = fp32_t;
|
||||
|
||||
// Fragment size, also the expected block size from the selector.
|
||||
// Note: Actual blockK might be slightly different due to hardware implementation, but the
|
||||
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
|
||||
// correct.
|
||||
static constexpr uint32_t FragM = 16;
|
||||
static constexpr uint32_t FragN = 16;
|
||||
static constexpr uint32_t FragK = 32;
|
||||
static constexpr uint32_t BlockM = FragM;
|
||||
static constexpr uint32_t BlockN = FragN;
|
||||
static constexpr uint32_t BlockK = FragK;
|
||||
|
||||
// The number of elements per thread
|
||||
uint32_t AElements = BlockM * BlockK / deviceWarpSize;
|
||||
uint32_t BElements = BlockN * BlockK / deviceWarpSize;
|
||||
uint32_t CElements = BlockM * BlockN / deviceWarpSize;
|
||||
|
||||
uint32_t ASize = AElements * sizeof(AType);
|
||||
uint32_t BSize = BElements * sizeof(BType);
|
||||
uint32_t CSize = CElements * sizeof(CType);
|
||||
|
||||
// Initialize A and B to all 1's, C to all 0's
|
||||
std::vector<AType> h_a(AElements, static_cast<AType>(1));
|
||||
std::vector<BType> h_b(BElements, static_cast<BType>(1));
|
||||
std::vector<CType> h_c(CElements, static_cast<CType>(0));
|
||||
std::vector<CType> h_out(CElements, static_cast<CType>(0));
|
||||
|
||||
AType* d_a;
|
||||
BType* d_b;
|
||||
CType* d_c;
|
||||
CType* d_out;
|
||||
|
||||
HIP_CHECK_ERROR(hipMalloc(&d_a, ASize));
|
||||
HIP_CHECK_ERROR(hipMalloc(&d_b, BSize));
|
||||
HIP_CHECK_ERROR(hipMalloc(&d_c, CSize));
|
||||
HIP_CHECK_ERROR(hipMalloc(&d_out, CSize));
|
||||
|
||||
// Copy inputs to device
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
|
||||
|
||||
const auto wave_size = getDeviceWaveSize();
|
||||
test_sparse_accum_over_k<AType, BType, CType, FragM, FragN, FragK>
|
||||
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
|
||||
|
||||
// Output should be FragK for all elements, because the inputs are all 1's
|
||||
for(size_t i = 0; i < CElements; ++i)
|
||||
{
|
||||
// In sparse only half of the A values are non-zero, thus the /2.
|
||||
CType expected = static_cast<CType>(FragK) / 2;
|
||||
|
||||
EXPECT_NEAR(h_out[i], expected, 1e-3);
|
||||
}
|
||||
|
||||
HIP_CHECK_ERROR(hipFree(d_a));
|
||||
HIP_CHECK_ERROR(hipFree(d_b));
|
||||
HIP_CHECK_ERROR(hipFree(d_c));
|
||||
HIP_CHECK_ERROR(hipFree(d_out));
|
||||
}
|
||||
Reference in New Issue
Block a user