Implement batched gemm bias permute for RDNA4 (#3534)

* feat: test setup for batched contraction (aka batched gemm multiple d e permute)

* wip: device struct for WMMA batched contraction multiple d based on new gridwise op

* feat: working batched contraction on RDNA, non-naive tensor descriptors for gridwise_gemm_wmma_cshuffle_v3, test setup for odd cases

* fix: failure to resolve template parameters when calling new function overload

* fix: passing reference type as parameter instead of underlying types

* fix: merge error caused duplicate definitions

* fix: make sure constness of template and parameters types match

* fix: don't compile batched contraction test on unsupported architectures

* feat: add example for new wmma implementation, and consolidate example code between platforms

* style: return inline instead of with branch

* chore: add extra assert on vector memory access sizes

* chore: clean up some unused variables

* fix: correct tail number calculation, added small cases and extra instances to the test

* fix: properly support wave transfer by generating correct grid descriptors dependent on the transfer method

[ROCm/composable_kernel commit: fe40a5d139]
This commit is contained in:
Erwin Terpstra
2026-01-17 08:30:27 +01:00
committed by GitHub
parent 80bc8aaf76
commit beffadc5a0
18 changed files with 2475 additions and 1009 deletions

View File

@@ -1,8 +1,9 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_batched_gemm_bias_permute_instance
device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp
device_batched_gemm_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_instance.cpp
)

View File

@@ -0,0 +1,78 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using F16_Tuple = ck::Tuple<F16>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
// A[g0, m0, m1, k0] * B[g0, n0, n1, n2, k0] + D[g0, m0, m1, n0, n1, n2] = E[g0, n0, m0, n0, n1, m1]
// m/n/n/n are the fast changing dimension for A/B/D/E
using device_batched_contraction_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_mnnm_instance =
std::tuple<
// clang-format off
//################################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| A| B| DE| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| CDEBlockTransfer|
//################################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Specialization| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//################################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 256, 256, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1>>,
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1>>,
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1>>,
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1>>,
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 256, 256, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<4, 4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 64, 64, 32, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<4, 4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>
// clang-format on
>;
void add_device_batched_contraction_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_mnnm_instance(
std::vector<std::unique_ptr<DeviceBatchedContractionMultipleD<1,
2,
3,
1,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_contraction_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_mnnm_instance{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck