Integrate new packed cast threadwise tensor slice transfer into gridwise gemm pipelines.

This commit is contained in:
Ville Pietilä
2025-08-15 12:06:44 +00:00
parent 6374e16a43
commit 00a3ce734a
4 changed files with 157 additions and 283 deletions

View File

@@ -4,13 +4,11 @@
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/env.hpp"
#include "ck/utility/type.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -101,18 +99,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// gfx950 specific optimizations for BF16 inputs
#if defined(__gfx950__)
static constexpr bool is_gfx950_and_bf16_input_ =
std::is_same_v<ADataType, ck::bhalf_t> &&
std::is_same_v<BDataType, ck::bhalf_t> &&
std::is_same_v<CShuffleDataType, ck::bhalf_t> &&
std::is_same_v<AccDataType, float>;
#else
static constexpr bool is_gfx950_and_bf16_input_ = false;
#endif
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
{
@@ -274,10 +261,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[GridwiseGemm_xdl_cshuffle_conv_v3] GFX950 and BF16 optimization enabled: " << is_gfx950_and_bf16_input_ << std::endl;
}
}
const ADataType* p_a_grid;
@@ -656,6 +639,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
// if arch = gfx942
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// check if we should apply gf950 device specific optimization for BF16 output
__device__ static constexpr bool is_gfx650_and_bf16_output()
{
#if defined(__gfx950__)
return
std::is_same_v<CShuffleDataType, ck::bhalf_t> &&
std::is_same_v<AccDataType, float>;
#else
return false;
#endif
}
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -900,9 +895,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -920,8 +915,30 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
1,
InMemoryDataOperationEnum::Set,
1,
true,
is_gfx950_and_bf16_input_>{
true>,
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>
>;
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
@@ -989,21 +1006,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx950_and_bf16_input_)
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
@@ -1288,9 +1290,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -1308,8 +1310,30 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
1,
InMemoryDataOperationEnum::Set,
1,
true,
is_gfx950_and_bf16_input_>{
true>,
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>
>;
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
@@ -1376,21 +1400,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx950_and_bf16_input_)
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,

View File

@@ -9,7 +9,6 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -1385,6 +1384,18 @@ struct GridwiseGemm_xdl_cshuffle_v3
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
// check if we should apply gf950 device specific optimization for BF16 output
__device__ static constexpr bool is_gfx650_and_bf16_output()
{
#if defined(__gfx950__)
return
std::is_same_v<CShuffleDataType, ck::bhalf_t> &&
std::is_same_v<AccDataType, float>;
#else
return false;
#endif
}
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -1647,9 +1658,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
};
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -1669,8 +1681,33 @@ struct GridwiseGemm_xdl_cshuffle_v3
1,
InMemoryDataOperationEnum::Set,
1,
true,
is_gfx950_and_bf16_input_>{
true>,
ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
conditional_t<DoElementwiseBeforeCShuffle,
CElementwiseOperation,
tensor_operation::element_wise::PassThrough>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>
>;
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
@@ -1740,21 +1777,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx950_and_bf16_input_)
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
@@ -2068,28 +2090,52 @@ struct GridwiseGemm_xdl_cshuffle_v3
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>,
ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>
>;
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true,
is_gfx950_and_bf16_input_>{
auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
@@ -2157,21 +2203,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx950_and_bf16_input_)
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),

View File

@@ -1,167 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
struct PackedCast
{
template <typename SrcDesc, typename SrcSliceOriginIdx, typename SrcBuffer>
__device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf)
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
"wrong! SrcSliceOrigin need to known at compile-time");
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
// Calculate total elements in this slice
constexpr index_t elements_per_slice =
CShuffleMXdlPerWavePerShuffle * CShuffleNXdlPerWavePerShuffle * M2 * M4;
constexpr auto calculate_coords = [src_slice_origin_idx](auto idx) constexpr {
// We know that the access order is
// Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
// Sequence<CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, 1, 1, M2, 1, M4, 1>
constexpr index_t m4_offset = idx.value % M4;
constexpr index_t m2_offset = (idx.value / M4) % M2;
constexpr index_t n_xdl_offset = (idx.value / (M4 * M2)) % CShuffleNXdlPerWavePerShuffle;
constexpr index_t m_xdl_offset = idx.value / (M4 * M2 * CShuffleNXdlPerWavePerShuffle);
return make_tuple(
src_slice_origin_idx[Number<0>{}] + Number<m_xdl_offset>{},
src_slice_origin_idx[Number<1>{}] + Number<n_xdl_offset>{},
Number<0>{}, // this dim has unit size
Number<0>{}, // this dim has unit size
src_slice_origin_idx[Number<4>{}] + Number<m2_offset>{},
Number<0>{}, // this dim has unit size
src_slice_origin_idx[Number<6>{}] + Number<m4_offset>{},
Number<0>{} // this dim has unit size
);
};
constexpr index_t num_pairs = elements_per_slice / 2;
constexpr bool has_odd_element = (elements_per_slice % 2 == 1);
static_for<0, num_pairs, 1>{}([&](auto pair_idx) {
constexpr auto idx_0 = Number<pair_idx * 2>{};
constexpr auto idx_1 = Number<pair_idx * 2 + 1>{};
constexpr auto coord_0 = calculate_coords(idx_0);
constexpr auto coord_1 = calculate_coords(idx_1);
constexpr auto offset_0 = src_desc.CalculateOffset(coord_0);
constexpr auto offset_1 = src_desc.CalculateOffset(coord_1);
float& val_0 = src_buf(Number<offset_0>{});
float& val_1 = src_buf(Number<offset_1>{});
static_cast_float_to_bhalf_packed(val_0, val_1);
});
// Handle last element if the number of elements is odd.
if constexpr (has_odd_element)
{
constexpr auto last_idx = Number<elements_per_slice - 1>{};
constexpr auto last_coord = calculate_coords(last_idx);
// Single element conversion
constexpr auto last_offset = src_desc.CalculateOffset(last_coord);
float& last_val = src_buf[Number<last_offset>{}];
const auto single_bf16 = static_cast<__bf16>(last_val);
uint16_t* parts = reinterpret_cast<uint16_t*>(&last_val);
const uint16_t* bf16_bits = reinterpret_cast<const uint16_t*>(&single_bf16);
parts[1] = bf16_bits[0];
}
};
};
template <ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
struct PackedCastV2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
template <typename SrcDesc, typename SrcSliceOriginIdx, typename SrcBuffer>
__device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf)
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
"wrong! SrcSliceOrigin need to known at compile-time");
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
using SliceLengths = Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>;
using DimAccessOrder = Sequence<0, 1, 2, 3, 4, 5, 6, 7>;
using DstScalarPerAccess = Sequence<1, 1, 1, 1, 1, 1, 1, 1>;
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
DstScalarPerAccess>;
static_assert(SpaceFillingCurve::ScalarPerVector == 1,
"wrong! SpaceFillingCurve::ScalarPerVector must be 1 for PackedCastV2");
constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
constexpr index_t num_pairs = num_access / 2;
constexpr bool has_odd_element = (num_access % 2 == 1);
static_for<0, num_pairs, 1>{}([&](auto i_pair)
{
constexpr auto idx_1d_0 = I2 * i_pair;
constexpr auto idx_1d_1 = I2 * i_pair + I1;
constexpr auto idx_md_0 = SpaceFillingCurve::GetIndex(idx_1d_0);
constexpr auto idx_md_1 = SpaceFillingCurve::GetIndex(idx_1d_1);
constexpr index_t src_offset_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_0);
constexpr index_t src_offset_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_1);
float& val_0 = src_buf(Number<src_offset_0>{});
float& val_1 = src_buf(Number<src_offset_1>{});
static_cast_float_to_bhalf_packed_v2(val_0, val_1);
});
// Handle last element if the number of elements is odd.
if constexpr (has_odd_element)
{
constexpr auto last_idx_1d = Number<num_access - 1>{};
constexpr auto last_idx_md = SpaceFillingCurve::GetIndex(last_idx_1d);
// Single element conversion
constexpr auto last_src_offset = src_desc.CalculateOffset(last_idx_1d);
float& last_val = src_buf(Number<last_src_offset>{});
const auto single_bf16 = static_cast<__bf16>(last_val);
uint16_t* parts = reinterpret_cast<uint16_t*>(&last_val);
const uint16_t* bf16_bits = reinterpret_cast<const uint16_t*>(&single_bf16);
parts[0] = bf16_bits[0];
}
};
};
}

View File

@@ -235,7 +235,8 @@ class TestGroupedConvndBwdWeight2d_bf16_gfx950 : public TestGroupedConvndBwdWeig
};
using KernelTypes2d_bf16_gfx950 = ::testing::Types<
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
// This layout does not yet work.
//std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NGCHW, GKYXC, NGKHW, ck::Number<2>>,
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NGCHW, GKCYX, NGKHW, ck::Number<2>>>;