[rocm-libraries] ROCm/rocm-libraries#4837 (commit 6316035)

[CK TILE] Unification of sparse MFMA/WMMA policy structs
 (#4837)

## Motivation

The existing unification work supports DENSE intrinsics. In this PR we
enable support for SPARSE as well as SCALE intrinsics and add an example
SPARSE implementation.

## Technical Details

Mostly trivial changes. One framework change is that the desired
`MmaOpFamily` is passed to the `MmaDefaultSelector`. As my relevant
commit explains, we do not support a fallback family at the moment, but
it is something we can consider.

## Test Plan

Added a new test for the relevant sparse specializations.

## Test Result

Test should pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
chris-tsiaousis-hpc
2026-03-05 19:53:16 +00:00
committed by assistant-librarian[bot]
parent 6e558658ea
commit 03ce21ddcb
23 changed files with 1173 additions and 89 deletions

View File

@@ -7,6 +7,10 @@ if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
# Build the sparse MMA test only for targets the test covers (gfx9 MFMA / gfx12 WMMA families).
if(GPU_TARGETS MATCHES "gfx9|gfx12")
add_gtest_executable(test_amdgcn_sparse_mma test_amdgcn_sparse_mma.cpp)
target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp)
target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -0,0 +1,34 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdint>
#include <cstdio>
#include "ck_tile/core/arch/arch.hpp"
#include <hip/hip_runtime.h>
#include "ck_tile/host/hip_check_error.hpp"
namespace {
// Probe kernel: writes the compile-time wave size of the selected compiler target to waveSize.
// Every launched thread stores the same value, so any launch shape is fine.
__global__ void getWaveSizeForSelectedOp(uint32_t* waveSize)
{
    using CompilerTarget = decltype(ck_tile::core::arch::get_compiler_target());
    if(waveSize == nullptr)
        return;
    *waveSize = static_cast<uint32_t>(CompilerTarget::WAVE_SIZE_ID);
}
/// @brief Runs a probe kernel on the device and returns the wave size the code was compiled for.
/// @return Wave size (e.g. 32 or 64) reported by the selected compiler target.
static __host__ uint32_t getDeviceWaveSize()
{
    uint32_t* d_wave_size;
    HIP_CHECK_ERROR(hipMalloc(&d_wave_size, sizeof(uint32_t)));
    // 64 threads covers any wave size; the kernel writes the same value from every thread.
    getWaveSizeForSelectedOp<<<1, 64>>>(d_wave_size);
    // Surface kernel-launch failures explicitly instead of at the next API call.
    HIP_CHECK_ERROR(hipGetLastError());
    HIP_CHECK_ERROR(hipDeviceSynchronize());
    uint32_t wave_size;
    HIP_CHECK_ERROR(hipMemcpy(&wave_size, d_wave_size, sizeof(uint32_t), hipMemcpyDeviceToHost));
    // Release the probe buffer (previously leaked on every call).
    HIP_CHECK_ERROR(hipFree(d_wave_size));
    return wave_size;
}
} // namespace

View File

@@ -11,6 +11,8 @@
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "get_wave_size_helper.hpp"
using namespace ck_tile;
using namespace ck_tile::core::arch;
using namespace ck_tile::core::arch::mma;
@@ -47,10 +49,12 @@ struct amdgcn_mma<fp32_t,
16u,
DummyCtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE,
enable_if_target_id_dummy_t<CompilerTarget>>
{
// Mfma operation type
using OpType = DummyOpType;
using OpType = DummyOpType;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
// Register types
using AVecType = ext_vector_t<fp32_t, 4>;
@@ -81,8 +85,15 @@ struct amdgcn_mma<fp32_t,
// Have an alias so we can test supported arch vs unsupported arch
// TODO: c++20 template <amdgcn_target_arch_id CompilerTarget>
template <typename CompilerTarget>
using DummyAmdgcnMma =
amdgcn_mma<fp32_t, fp32_t, fp32_t, 16u, 16u, 16u, DummyCtrlFlags, CompilerTarget>;
using DummyAmdgcnMma = amdgcn_mma<fp32_t,
fp32_t,
fp32_t,
16u,
16u,
16u,
DummyCtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE>;
/*! @struct MmaDefaultSelector
 * @brief For dummy Id only, instantiate tests for both MFMA and WMMA selectors so we can test them both
@@ -93,6 +104,7 @@ using DummyAmdgcnMma =
* @tparam FragN Size of the N dimension of the fragment to decompose
* @tparam FragK Size of the K dimension of the fragment to decompose
* @tparam CompilerTarget The compiler target
* @tparam OpFamily The MMA operation family
*/
template <typename ADataType,
typename BDataType,
@@ -100,7 +112,8 @@ template <typename ADataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
typename CompilerTarget>
typename CompilerTarget,
MmaOpFamily OpFamily>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: requires
struct MmaDefaultSelector<ADataType,
@@ -110,7 +123,9 @@ struct MmaDefaultSelector<ADataType,
FragN,
FragK,
CompilerTarget,
enable_if_target_id_dummy_t<CompilerTarget>>
OpFamily,
enable_if_all<enable_if_target_id_dummy_t<CompilerTarget>,
std::enable_if_t<OpFamily == MmaOpFamily::DENSE>>>
{
using SelectedOp = DummyAmdgcnMma<CompilerTarget>;
};
@@ -128,6 +143,8 @@ TEST(TestAmdgcnMma, ArchSupported)
// Check OpType
EXPECT_TRUE(
(std::is_same<typename MmaOp::OpType, DummyOpType>::value)); // OpType is DummyOpType
// Check OpFamily
EXPECT_TRUE((is_mma_op_of_family_v<MmaOpFamily::DENSE, MmaOp>));
// Check AVecType, BVecType, CVecType
EXPECT_TRUE((std::is_same<typename MmaOp::AVecType, ext_vector_t<fp32_t, 4>>::value));
@@ -157,6 +174,8 @@ TEST(TestAmdgcnMma, ArchUnsupported)
// OpType should be Unsupported
EXPECT_TRUE((std::is_same<typename MmaOp::OpType, Unsupported>::value));
// OpFamily should be Undefined
EXPECT_TRUE((is_mma_op_of_family_v<MmaOpFamily::UNDEFINED, MmaOp>));
// AVecType, BVecType, CVecType should match default
EXPECT_TRUE((std::is_same<typename MmaOp::AVecType, ext_vector_t<fp32_t, 1>>::value));
@@ -367,6 +386,7 @@ TEST(TestAmdgcnMma, MmaOpTraitsUnsupportedMembers)
EXPECT_TRUE((std::is_same<typename Traits::AVecType, ext_vector_t<fp32_t, 1>>::value));
EXPECT_TRUE((std::is_same<typename Traits::BVecType, ext_vector_t<fp32_t, 1>>::value));
EXPECT_TRUE((std::is_same<typename Traits::CVecType, ext_vector_t<fp32_t, 1>>::value));
EXPECT_EQ(Traits::OpFamily, MmaOpFamily::UNDEFINED);
EXPECT_EQ(Traits::kAMBlock, 0);
EXPECT_EQ(Traits::kBNBlock, 0);
EXPECT_EQ(Traits::kAMLane, 0);
@@ -386,9 +406,14 @@ TEST(TestAmdgcnMma, MmaOpTraitsUnsupportedMembers)
TEST(TestAmdgcnMma, MmaDefaultSelectorSupported)
{
// Direct selection of the supported dummy instruction
using SelectedMma =
typename MmaDefaultSelector<fp32_t, fp32_t, fp32_t, 16u, 16u, 16u, DummyCompilerTarget>::
SelectedOp;
using SelectedMma = typename MmaDefaultSelector<fp32_t,
fp32_t,
fp32_t,
16u,
16u,
16u,
DummyCompilerTarget,
MmaOpFamily::DENSE>::SelectedOp;
// Should select DummyAmdgcnMma specialization
EXPECT_TRUE((std::is_same<SelectedMma, DummyAmdgcnMma<DummyCompilerTarget>>::value));
// OpType should be DummyOpType
@@ -401,8 +426,14 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorSupported)
TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupported)
{
// Direct selection of the unsupported dummy instruction
using SelectedMma =
MmaDefaultSelector<fp32_t, fp32_t, fp32_t, 16u, 16u, 16u, amdgcn_target<>>::SelectedOp;
using SelectedMma = MmaDefaultSelector<fp32_t,
fp32_t,
fp32_t,
16u,
16u,
16u,
amdgcn_target<>,
MmaOpFamily::UNDEFINED>::SelectedOp;
// OpType should be Unsupported
EXPECT_TRUE((std::is_same<typename SelectedMma::OpType, Unsupported>::value));
// IsSupported should be false
@@ -414,9 +445,14 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupported)
TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedFragment)
{
// Select indirectly with a fragment size of 256x128x64
using SelectedMma =
MmaDefaultSelector<fp32_t, fp32_t, fp32_t, 256u, 128u, 64u, DummyCompilerTarget>::
SelectedOp;
using SelectedMma = MmaDefaultSelector<fp32_t,
fp32_t,
fp32_t,
256u,
128u,
64u,
DummyCompilerTarget,
MmaOpFamily::DENSE>::SelectedOp;
// Should select DummyAmdgcnMma specialization
EXPECT_TRUE((std::is_same<SelectedMma, DummyAmdgcnMma<DummyCompilerTarget>>::value));
// OpType should be DummyOpType
@@ -429,8 +465,14 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedFragment)
// Selector behaviour for an 8x8x8 fragment against the 16x16x16 dummy op.
TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedFragment)
{
// This should fall back to unsupported since DummyAmdgcnMma only supports 16x16x16
// NOTE(review): the assertions below expect a *supported* selection
// (OpType != Unsupported, IsSupported == true), which contradicts the comment above and the
// test name. Confirm whether the selector is intended to accept fragments smaller than the
// op's 16x16x16 block, and then fix either the assertions or the name/comment to match.
using SelectedMma =
MmaDefaultSelector<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, DummyCompilerTarget>::SelectedOp;
using SelectedMma = MmaDefaultSelector<fp32_t,
fp32_t,
fp32_t,
8u,
8u,
8u,
DummyCompilerTarget,
MmaOpFamily::DENSE>::SelectedOp;
EXPECT_FALSE((std::is_same<typename SelectedMma::OpType, Unsupported>::value));
EXPECT_TRUE(MmaOpTraits<SelectedMma>::IsSupported);
}
@@ -438,8 +480,14 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedFragment)
// Test MmaDefaultSelector for a different data type (fp16_t) and unsupported arch
TEST(TestAmdgcnMma, MmaDefaultSelectorFp16Unsupported)
{
using SelectedMma =
MmaDefaultSelector<fp16_t, fp16_t, fp16_t, 16u, 16u, 16u, amdgcn_target<>>::SelectedOp;
using SelectedMma = MmaDefaultSelector<fp16_t,
fp16_t,
fp16_t,
16u,
16u,
16u,
amdgcn_target<>,
MmaOpFamily::UNDEFINED>::SelectedOp;
// Should select default amdgcn_mma (Unsupported)
EXPECT_TRUE((std::is_same<typename SelectedMma::OpType, Unsupported>::value));
EXPECT_FALSE(MmaOpTraits<SelectedMma>::IsSupported);
@@ -464,7 +512,8 @@ __global__ void test_accum_over_k(void* a, void* b, void* c, void* out)
FragM,
FragN,
FragK,
decltype(get_compiler_target())>;
decltype(get_compiler_target()),
MmaOpFamily::DENSE>;
using MmaOp = typename Selector::SelectedOp;
using MmaTraits = MmaOpTraits<MmaOp>;
@@ -561,8 +610,9 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
// Need at least 1 WG with a full wavefront (wave_size threads) to get defined MFMA/WMMA behaviour
test_accum_over_k<AType, BType, CType, FragM, FragN, FragK><<<1, 64>>>(d_a, d_b, d_c, d_out);
const auto wave_size = getDeviceWaveSize();
test_accum_over_k<AType, BType, CType, FragM, FragN, FragK>
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
@@ -661,8 +711,9 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
// Need at least 1 WG with a full wavefront (wave_size threads) to get defined MFMA/WMMA behaviour
test_accum_over_k<AType, BType, CType, FragM, FragN, FragK><<<1, 64>>>(d_a, d_b, d_c, d_out);
const auto wave_size = getDeviceWaveSize();
test_accum_over_k<AType, BType, CType, FragM, FragN, FragK>
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));

View File

@@ -0,0 +1,274 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include <iostream>
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include <hip/hip_runtime.h>
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "get_wave_size_helper.hpp"
using namespace ck_tile;
using namespace ck_tile::core::arch;
using namespace ck_tile::core::arch::mma;
using CompilerTargetGfx950 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX950>());
// Verify the GFX950 sparse MFMA specialization (fp16 x fp16 -> fp32, 16x16x32): it must
// resolve to the MFMA op type, carry the SPARSE family tag, and be detected by the
// family-membership trait.
TEST(SparseMMATrait, SparseMfmaGfx950Specialization)
{
    using SparseMfma16x16x32 = amdgcn_mma<fp16_t,
                                          fp16_t,
                                          fp32_t,
                                          16u,
                                          16u,
                                          32u,
                                          DefaultSparseMfmaCtrlFlags,
                                          CompilerTargetGfx950,
                                          MmaOpFamily::SPARSE>;
    static_assert(std::is_same_v<typename SparseMfma16x16x32::OpType, MfmaOp>,
                  "GFX950 sparse 16x16x32 should have SparseMFMAOp type");
    static_assert(SparseMfma16x16x32::OpFamily == MmaOpFamily::SPARSE,
                  "GFX950 sparse 16x16x32 should have SparseMFMAOp type");
    static_assert(is_mma_op_of_family_v<MmaOpFamily::SPARSE, SparseMfma16x16x32>,
                  "GFX950 sparse 16x16x32 should be detected as Sparse");
    std::cout << "GFX950 sparse MFMA specialization is correct" << std::endl;
}
// Verify that MmaOpTraits classifies a sparse MFMA specialization as sparse, supported,
// MFMA-flavoured, and not WMMA.
TEST(SparseMMATrait, MmaOpTraitsIntegration)
{
    // A sparse MMA op (16x16x32 fp16 specialization).
    using SparseOp = amdgcn_mma<fp16_t,
                                fp16_t,
                                fp32_t,
                                16u,
                                16u,
                                32u,
                                DefaultSparseMfmaCtrlFlags,
                                CompilerTargetGfx950,
                                MmaOpFamily::SPARSE>;
    // Its traits drive all of the checks below.
    using SparseTraits = MmaOpTraits<SparseOp>;
    static_assert(SparseTraits::IsSparse, "Sparse MMA should be detected as sparse");
    static_assert(SparseTraits::IsSupported, "Sparse MMA specialization should be supported");
    static_assert(SparseTraits::IsMfma, "Sparse MFMA should be detected as MFMA");
    static_assert(!SparseTraits::IsWmma, "Sparse MFMA should not be detected as WMMA");
    std::cout << "MmaOpTraits correctly integrates sparse operations" << std::endl;
}
// Verify that a dense and a sparse MFMA with identical shapes/types share the same OpType tag
// but carry different OpFamily values, and that MmaOpTraits tells them apart.
TEST(SparseMMATrait, DenseVsSparseDistinction)
{
    // Dense MFMA from mfma/mfma_gfx9.hpp
    using DenseOp = amdgcn_mma<fp16_t,
                               fp16_t,
                               fp32_t,
                               16u,
                               16u,
                               32u,
                               DefaultMfmaCtrlFlags,
                               CompilerTargetGfx950,
                               MmaOpFamily::DENSE>;
    // Sparse MFMA on GFX950
    using SparseOp = amdgcn_mma<fp16_t,
                                fp16_t,
                                fp32_t,
                                16u,
                                16u,
                                32u,
                                DefaultSparseMfmaCtrlFlags,
                                CompilerTargetGfx950,
                                MmaOpFamily::SPARSE>;
    // Same op-type tag, different family tag.
    static_assert(std::is_same_v<typename DenseOp::OpType, typename SparseOp::OpType> &&
                      DenseOp::OpFamily != SparseOp::OpFamily,
                  "Dense and Sparse MFMA should have the same OpType tags and different OpFamily");
    using DenseTraits  = MmaOpTraits<DenseOp>;
    using SparseTraits = MmaOpTraits<SparseOp>;
    // Traits must identify each family correctly.
    static_assert(DenseTraits::IsMfma && DenseTraits::IsDense && !DenseTraits::IsSparse &&
                      !DenseTraits::IsScale && DenseTraits::IsSupported,
                  "Dense MFMA should be identified correctly");
    static_assert(SparseTraits::IsSparse && SparseTraits::IsMfma && !SparseTraits::IsDense &&
                      !SparseTraits::IsScale && SparseTraits::IsSupported,
                  "Sparse MFMA should be identified correctly");
    std::cout << "Dense and sparse MMA operations are correctly distinguished" << std::endl;
}
// Sweep square fragments 1..32 (with K = 2*M) through the SPARSE-family selector. Only the
// sizes the sparse specializations cover (16 and 32) should yield a supported sparse MFMA;
// every other size must fall back to the unsupported pass-through.
TEST(SparseMMATrait, SparseSelector)
{
    static_for<1, 33, 1>{}([](auto dim) {
        using Picked = typename MmaDefaultSelector<fp16_t,
                                                   fp16_t,
                                                   fp32_t,
                                                   static_cast<uint32_t>(dim),
                                                   static_cast<uint32_t>(dim),
                                                   static_cast<uint32_t>(2 * dim),
                                                   CompilerTargetGfx950,
                                                   MmaOpFamily::SPARSE>::SelectedOp;
        if constexpr(dim == 16 || dim == 32)
        {
            // Selector should pick a sparse MFMA implementation.
            static_assert(MmaOpTraits<Picked>::IsSparse);
            static_assert(MmaOpTraits<Picked>::IsMfma);
            static_assert(MmaOpTraits<Picked>::IsSupported);
            static_assert(std::is_same_v<typename Picked::OpType, MfmaOp>);
        }
        else
        {
            // Selector should pick the unsupported pass-through.
            static_assert(!MmaOpTraits<Picked>::IsSupported);
        }
    });
}
// Device kernel: accumulates A*B with the SPARSE-family op chosen by MmaDefaultSelector for
// the current compiler target, looping FragK / BlockK times so the total accumulated K equals
// FragK, then writes the accumulator to out.
template <typename AType,
          typename BType,
          typename CType,
          uint32_t FragM,
          uint32_t FragN,
          uint32_t FragK>
__global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out)
{
    using Selector = MmaDefaultSelector<AType,
                                        BType,
                                        CType,
                                        FragM,
                                        FragN,
                                        FragK,
                                        decltype(get_compiler_target()),
                                        MmaOpFamily::SPARSE>;
    using MmaOp    = typename Selector::SelectedOp;
    using AVecType = typename MmaOp::AVecType;
    using BVecType = typename MmaOp::BVecType;
    using CVecType = typename MmaOp::CVecType;

    constexpr uint32_t num_iters = FragK / MmaOpTraits<MmaOp>::BlockK;

    // Seed the accumulator with the provided C values.
    CVecType acc = *static_cast<CVecType*>(c);
    // Feed the same A/B fragments through the op until the full K extent is covered.
    for(uint32_t iter = 0; iter < num_iters; ++iter)
    {
        acc = MmaOp::exec(*static_cast<AVecType*>(a), *static_cast<BVecType*>(b), acc);
    }
    *static_cast<CVecType*>(out) = acc;
}
// Live test on real hardware for sparse selection and execution: all-ones fp16 inputs through
// the selected 16x16x32 sparse op must accumulate to FragK / 2 per output element.
TEST(SparseMMATrait, MmaSelector_Sparse_F16_F16_F32_16x16x32_Real)
{
    // Check the device count before any call that would fail loudly without a device, so a
    // device-less machine skips instead of aborting inside HIP_CHECK_ERROR.
    int devCount = 0;
    if(hipGetDeviceCount(&devCount) != hipSuccess || devCount == 0)
    {
        GTEST_SKIP() << "No HIP device found. Skipping test.";
    }
    hipDevice_t dev;
    HIP_CHECK_ERROR(hipGetDevice(&dev));
    hipDeviceProp_t devProp;
    HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev));
    auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName);
    int deviceWarpSize = devProp.warpSize;
    bool isSupportedWmma = (currentArchId >= amdgcn_target_id::GFX1200) &&
                           (currentArchId <= amdgcn_target_id::GFX12_GENERIC);
    bool isSupportedMfma =
        (currentArchId >= amdgcn_target_id::GFX942) && (currentArchId <= amdgcn_target_id::GFX950);
    // TODO: c++20 add check for arch id
    if((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma))
    {
        // Report the actual skip reason (previously reported as "No HIP device found").
        GTEST_SKIP() << "Device arch has no sparse MFMA/WMMA support. Skipping test.";
    }
    using AType = fp16_t;
    using BType = fp16_t;
    using CType = fp32_t;
    // Fragment size, also the expected block size from the selector.
    // Note: Actual blockK might be slightly different due to hardware implementation, but the
    // test_sparse_accum_over_k kernel will loop over the K dimension to ensure that the total K
    // is correct.
    static constexpr uint32_t FragM = 16;
    static constexpr uint32_t FragN = 16;
    static constexpr uint32_t FragK = 32;
    static constexpr uint32_t BlockM = FragM;
    static constexpr uint32_t BlockN = FragN;
    static constexpr uint32_t BlockK = FragK;
    // The number of elements per thread
    uint32_t AElements = BlockM * BlockK / deviceWarpSize;
    uint32_t BElements = BlockN * BlockK / deviceWarpSize;
    uint32_t CElements = BlockM * BlockN / deviceWarpSize;
    uint32_t ASize = AElements * sizeof(AType);
    uint32_t BSize = BElements * sizeof(BType);
    uint32_t CSize = CElements * sizeof(CType);
    // Initialize A and B to all 1's, C to all 0's
    std::vector<AType> h_a(AElements, static_cast<AType>(1));
    std::vector<BType> h_b(BElements, static_cast<BType>(1));
    std::vector<CType> h_c(CElements, static_cast<CType>(0));
    std::vector<CType> h_out(CElements, static_cast<CType>(0));
    AType* d_a;
    BType* d_b;
    CType* d_c;
    CType* d_out;
    HIP_CHECK_ERROR(hipMalloc(&d_a, ASize));
    HIP_CHECK_ERROR(hipMalloc(&d_b, BSize));
    HIP_CHECK_ERROR(hipMalloc(&d_c, CSize));
    HIP_CHECK_ERROR(hipMalloc(&d_out, CSize));
    // Copy inputs to device
    HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice));
    HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
    HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
    // Launch one full wavefront; the wave size is queried from the device to cover wave32/wave64.
    const auto wave_size = getDeviceWaveSize();
    test_sparse_accum_over_k<AType, BType, CType, FragM, FragN, FragK>
        <<<1, wave_size>>>(d_a, d_b, d_c, d_out);
    HIP_CHECK_ERROR(hipDeviceSynchronize());
    HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
    // Inputs are all 1's, but in sparse only half of the A values are non-zero, so every
    // output element accumulates FragK / 2 (not FragK as a dense op would).
    for(size_t i = 0; i < CElements; ++i)
    {
        CType expected = static_cast<CType>(FragK) / 2;
        EXPECT_NEAR(h_out[i], expected, 1e-3);
    }
    HIP_CHECK_ERROR(hipFree(d_a));
    HIP_CHECK_ERROR(hipFree(d_b));
    HIP_CHECK_ERROR(hipFree(d_c));
    HIP_CHECK_ERROR(hipFree(d_out));
}