mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
Implement batched gemm bias permute for RDNA4 (#3534)
* feat: test setup for batched contraction (aka batched gemm multiple d e permute)
* wip: device struct for WMMA batched contraction multiple d based on new gridwise op
* feat: working batched contraction on RDNA, non-naive tensor descriptors for gridwise_gemm_wmma_cshuffle_v3, test setup for odd cases
* fix: failure to resolve template parameters when calling new function overload
* fix: passing reference type as parameter instead of underlying types
* fix: merge error caused duplicate definitions
* fix: make sure constness of template and parameters types match
* fix: don't compile batched contraction test on unsupported architectures
* feat: add example for new wmma implementation, and consolidate example code between platforms
* style: return inline instead of with branch
* chore: add extra assert on vector memory access sizes
* chore: clean up some unused variables
* fix: correct tail number calculation, added small cases and extra instances to the test
* fix: properly support wave transfer by generating correct grid descriptors dependent on the transfer method
[ROCm/composable_kernel commit: fe40a5d139]
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
add_instance_library(device_batched_gemm_bias_permute_instance
|
||||
device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp
|
||||
device_batched_gemm_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_instance.cpp
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
|
||||
// setting Don't use this hack unless absolutely necessary!
|
||||
// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using F16_Tuple = ck::Tuple<F16>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed;
|
||||
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
|
||||
|
||||
// A[g0, m0, m1, k0] * B[g0, n0, n1, n2, k0] + D[g0, m0, m1, n0, n1, n2] = E[g0, n0, m0, n0, n1, m1]
|
||||
// m/n/n/n are the fast changing dimension for A/B/D/E
|
||||
using device_batched_contraction_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_mnnm_instance =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| A| B| DE| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| CDEBlockTransfer|
|
||||
//################################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Specialization| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//################################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 256, 256, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1>>,
|
||||
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1>>,
|
||||
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1>>,
|
||||
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1>>,
|
||||
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 256, 256, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<4, 4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 64, 64, 32, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<4, 4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_batched_contraction_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_mnnm_instance(
|
||||
std::vector<std::unique_ptr<DeviceBatchedContractionMultipleD<1,
|
||||
2,
|
||||
3,
|
||||
1,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_contraction_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_mnnm_instance{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user