mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Implement batched gemm gemm for RDNA (3 and 4) (#2612)
* Create new copies of the existing device struct and gridwise struct for batched_gemm_softmax_gemm and disable the softmax part. Still based on the old wmma pipelines. Also copy the example and remove the softmax part from the reference calculation. Works and results match the reference except for tiny float errors in problem 2.
* Turn DeviceBatchedGemmGemm_Wmma_CShuffleV3 into a proper DeviceBatchedGemmGemm derived class, with the right argument and invoker functions. Update the example to use the new definitions.
* Remove unused cross-attention and self-attention kernels, arguments, and invokers. Also remove other unused Argument types.
* Remove masking-related code; test unusual sizes in the example.
* Remove the remaining softmax-related code from GridwiseBatchedGemmGemm_wmma_cshuffle_v3 and the example.
* Remove code related to numDims, bias, and TensorSpec from the device struct and example.
* Add layout template parameters to the device struct.
* Move the (NPerBlock, LTilePerBlock) device struct template arguments up by two places to match the XDL template argument ordering.
* Merge the accumulation data types into one type to match the XDL device struct.
* Remove the NPerWmma template parameter from the device struct and set it equal to LPerWmma. The device struct template params now exactly match those of the XDL batched gemm gemm.
* Add support for the RCCR layout and test it in the example.
* Add batched_gemm_gemm_wmma to the instance library and profiler, and add a gtest just like for xdl.
* Add an RCCR instance and an additional RCRR instance to the library.
* Remove unused permute- and alpha-related code. Time all tests. Fix B1 strides in argument verification.
* Remove references to G0 and G1 in favor of batch; reduce the dimensionality of the length and stride arrays.
* Replace the old wmma gridwise pipeline and blockwise struct with the new wmma blockwise pipeline. Some cleanup remains, but all tests pass.
* Make TransposeC a proper template parameter that is passed all the way from BlockGemmPipeline_Selector to WmmaGemm so we can use the correct settings for batched gemm gemm as well as regular gemm. Gemm universal tests now pass again.
* Replace the old LoopSched and PipelineVer params with their BlockwiseGemm pipeline equivalents, and use these in the instance factory. The v3 pipeline does not work yet, but v1 works for intrawave and interwave.
* Adapt the A wave descriptor to deal with RDNA4 wmma. This fixes batched gemm gemm functionality on RDNA4.
* Fix two incorrect aspects of the v3 pipeline. First, the blockwise copy operator was invoked once too often in all cases (RunRead and move window), which broke batched gemm gemm when the blockwise pipeline was used multiple times. Second, we should use the mainloop (hotloop) for num_k_loop >= 2 instead of num_k_loop >= 3. We can now support any K dimension.
* Remove the num prefetch parameter from the gridwise struct since we don't use it and it doesn't do anything.
* Remove unused non-lds paths.
* Test and update the IsSupportedArgument() and CheckValidity() functions for all layouts + padding modes and various problem sizes.
* Add many instances to the profiler with various block sizes and pipelines, all verified.
* Add support for BF16: instance library, tests, and examples.
* Add examples for int8 and fp8; the latter required adding type_convert_sp template specializations.
* Template the library instance lists and add default padding instances.
* Move memory calculations from the kernel to the Argument constructor. Also actually parse and use the user-provided batch strides.
* Actually parse and use user-provided regular strides.
* More refactoring: remove references to multiple dims per dim, and g0 / g1. Also move xdl-specific test utils out of the generic test util header.
* Small post-rebase-on-develop fix due to bscale-related pipeline changes. All tests rerun; tested bscale and regular gemm.
* Introduce the correct GetCThreadDescriptor function in the blockwise gemm pipelines for the TransposeC=true case. It turns out to be identical for our batched gemm gemm (gemm0) use cases, but could theoretically differ for wmma_gemm instances with smaller-than-4-byte output data size.
* Remove the unused NumPrefetch template parameter; we don't need to match the XDL template params one-to-one.
* Implement proper TailNum and HasMainLoop template parameters for the v3 pipeline. The Run() function now knows at compile time whether there are 1, 2, or more loops in total, and adds or removes sections accordingly. It still uses the blockwise copy operators the correct number of times.
* Add a print lambda with env check, file, and function to device- and gridwise-level compatibility error messages. Also respect compatibility in the example script.
* RDNA3 does not support fp8.
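The commit message above repeatedly refers to the instance-library pattern: an `add_device_..._instance` function appends concrete tuning instances to a `std::vector` of owning base-class pointers, and a helper expands a tuple of instance types into that vector. A minimal self-contained sketch of that registration shape (all class and function names here are illustrative stand-ins, not CK's actual types):

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for an abstract device-operation base class.
struct DeviceOpBase
{
    virtual ~DeviceOpBase()                   = default;
    virtual std::string GetTypeString() const = 0;
};

// Hypothetical stand-in for one concrete tuning instance (block size as the
// only "tuning parameter" for brevity).
template <int BlockSize>
struct DeviceOpInstance : DeviceOpBase
{
    std::string GetTypeString() const override
    {
        return "instance_blocksize_" + std::to_string(BlockSize);
    }
};

// Mirrors the shape of an add-instances helper: append one default-constructed
// instance per pack element, in order (left-to-right fold over the comma).
template <typename... Instances>
void add_instances(std::vector<std::unique_ptr<DeviceOpBase>>& instances)
{
    (instances.push_back(std::make_unique<Instances>()), ...);
}
```

A caller would then collect every registered instance and benchmark each one, which is essentially what the profiler integration described above does.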
committed by GitHub
parent ef6c28e989
commit 7330ec37ee
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -16,6 +16,70 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#ifdef CK_USE_WMMA
#ifdef CK_ENABLE_BF16
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Row,
                                                      Row,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances);
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Col,
                                                      Row,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances);
#endif // CK_ENABLE_BF16
#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Row,
                                                      Row,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances);
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Col,
                                                      Row,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances);
#endif // CK_ENABLE_FP16
#endif // CK_USE_WMMA
#ifdef CK_USE_XDL
#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,

@@ -46,6 +110,8 @@ void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_i
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances);
#endif // CK_ENABLE_FP16
#endif // CK_USE_XDL
template <typename ALayout,
          typename B0Layout,
          typename B1Layout,

@@ -86,7 +152,46 @@ struct DeviceOperationInstanceFactory<
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_USE_WMMA
#ifdef CK_ENABLE_BF16
        if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<B0DataType, bhalf_t> &&
                     is_same_v<B1DataType, bhalf_t> && is_same_v<CDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
                         is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
            {
                add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance(
                    op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
                              is_same_v<B1Layout, Col> && is_same_v<CLayout, Row>)
            {
                add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance(
                    op_ptrs);
            }
        }
#endif // CK_ENABLE_BF16
#ifdef CK_ENABLE_FP16
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
                     is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
                         is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
            {
                add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
                    op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
                              is_same_v<B1Layout, Col> && is_same_v<CLayout, Row>)
            {
                add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
                    op_ptrs);
            }
        }
#endif // CK_ENABLE_FP16
#endif // CK_USE_WMMA
#ifdef CK_USE_XDL
#ifdef CK_ENABLE_FP16
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
                     is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
        {

@@ -103,10 +208,11 @@ struct DeviceOperationInstanceFactory<
                op_ptrs);
            }
        }
#endif // CK_ENABLE_FP16
#endif // CK_USE_XDL
        return op_ptrs;
    }
};
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation

@@ -1,5 +1,9 @@
-# ONLY XDL_KERNELS
+# ONLY XDL_AND_WMMA_KERNELS
 add_instance_library(device_batched_gemm_gemm_instance
+    device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
+    device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance.cpp
+    device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+    device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
     device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
     device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
 )

@@ -0,0 +1,92 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F32  = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmPadded  = GemmSpecialization::MNKOPadding;

static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

static constexpr auto PipeVerV1 = BlockGemmPipelineVersion::v1;
static constexpr auto PipeVerV3 = BlockGemmPipelineVersion::v3;

// gemm0: Acc[g, m, n] = A[g, m, k] * B0[g, k, n]
// gemm1: C[g, m, o] = Acc[g, m, n] * B1[g, n, o]
// Note that in some cases the "m, o, n" dimensions are referred to as the "gemm1 m, n, k"
// dimensions instead!
template <GemmSpecialization GemmSpec,
          BlockGemmPipelineScheduler PipeSched,
          BlockGemmPipelineVersion PipeVer>
using device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances =
    std::tuple<
        // clang-format off
        //################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemm| BlkGemm|
        //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| PipeSched| PipelineVer|
        //################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
        //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>
        // clang-format on
        >;

void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Row,
                                                      Row,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances)
{
    // clang-format off
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmDefault, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmDefault, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmDefault, Intrawave, PipeVerV3>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmPadded, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmPadded, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmPadded, Intrawave, PipeVerV3>{});
    // clang-format on
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

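The gemm0/gemm1 comment in the instance files above defines the fused computation these kernels perform. As a rough illustration, the math can be written as a naive host-side sketch (assuming densely packed row-major [G, M, K], [G, K, N], [G, N, O], [G, M, O] tensors; this is a sketch, not CK's reference implementation):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive reference for the fused batched gemm-gemm:
//   gemm0: Acc[g, m, n] = sum_k A[g, m, k] * B0[g, k, n]
//   gemm1: C[g, m, o]   = sum_n Acc[g, m, n] * B1[g, n, o]
// The intermediate Acc is materialized one row at a time, mirroring how the
// kernel never writes the full [M, N] product to memory.
std::vector<float> reference_batched_gemm_gemm(const std::vector<float>& a,
                                               const std::vector<float>& b0,
                                               const std::vector<float>& b1,
                                               std::size_t G,
                                               std::size_t M,
                                               std::size_t N,
                                               std::size_t K,
                                               std::size_t O)
{
    std::vector<float> c(G * M * O, 0.f);
    for(std::size_t g = 0; g < G; ++g)
        for(std::size_t m = 0; m < M; ++m)
        {
            std::vector<float> acc(N, 0.f); // one row of Acc[g, m, :]
            for(std::size_t n = 0; n < N; ++n)
                for(std::size_t k = 0; k < K; ++k)
                    acc[n] += a[(g * M + m) * K + k] * b0[(g * K + k) * N + n];
            for(std::size_t o = 0; o < O; ++o)
                for(std::size_t n = 0; n < N; ++n)
                    c[(g * M + m) * O + o] += acc[n] * b1[(g * N + n) * O + o];
        }
    return c;
}
```

The gno/gon suffixes in the instance names encode whether B1 is traversed row-major or column-major in memory; the math above is layout-agnostic.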
@@ -0,0 +1,92 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F32  = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmPadded  = GemmSpecialization::MNKOPadding;

static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

static constexpr auto PipeVerV1 = BlockGemmPipelineVersion::v1;
static constexpr auto PipeVerV3 = BlockGemmPipelineVersion::v3;

// gemm0: Acc[g, m, n] = A[g, m, k] * B0[g, k, n]
// gemm1: C[g, m, o] = Acc[g, m, n] * B1[g, n, o]
// Note that in some cases the "m, o, n" dimensions are referred to as the "gemm1 m, n, k"
// dimensions instead!
template <GemmSpecialization GemmSpec,
          BlockGemmPipelineScheduler PipeSched,
          BlockGemmPipelineVersion PipeVer>
using device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances =
    std::tuple<
        // clang-format off
        //################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemm| BlkGemm|
        //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| PipeSched| PipelineVer|
        //################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
        //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>
        // clang-format on
        >;

void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Col,
                                                      Row,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances)
{
    // clang-format off
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmDefault, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmDefault, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmDefault, Intrawave, PipeVerV3>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmPadded, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmPadded, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmPadded, Intrawave, PipeVerV3>{});
    // clang-format on
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

@@ -0,0 +1,92 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmPadded = GemmSpecialization::MNKOPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto PipeVerV1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto PipeVerV3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// gemm0: Acc[g, m, n] = A[g, m, k] * B0[g, k, n]
|
||||
// gemm1: C[g, m, o] = Acc[g, m, n] * B1[g, n, o]
|
||||
// Note that in some cases the "m, o, n" dimensions are referred to as the "gemm1 m, n, k"
|
||||
// dimensions instead!
|
||||
template <GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler PipeSched,
|
||||
BlockGemmPipelineVersion PipeVer>
|
||||
using device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
|
||||
std::
|
||||
tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemm| BlkGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| PipeSched| PipelineVer|
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>
// clang-format on
>;

void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Row,
                                                      Row,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances)
{
    // clang-format off
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmDefault, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmDefault, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmDefault, Intrawave, PipeVerV3>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmPadded, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmPadded, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmPadded, Intrawave, PipeVerV3>{});
// clang-format on
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmPadded = GemmSpecialization::MNKOPadding;

static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

static constexpr auto PipeVerV1 = BlockGemmPipelineVersion::v1;
static constexpr auto PipeVerV3 = BlockGemmPipelineVersion::v3;

// gemm0: Acc[g, m, n] = A[g, m, k] * B0[g, k, n]
// gemm1: C[g, m, o] = Acc[g, m, n] * B1[g, n, o]
// Note that in some cases the "m, o, n" dimensions are referred to as the "gemm1 m, n, k"
// dimensions instead!
template <GemmSpecialization GemmSpec,
          BlockGemmPipelineScheduler PipeSched,
          BlockGemmPipelineVersion PipeVer>
using device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances =
    std::tuple<
// clang-format off
//################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemm| BlkGemm|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| PipeSched| PipelineVer|
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>
// clang-format on
>;

void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Col,
                                                      Row,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances)
{
    // clang-format off
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmDefault, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmDefault, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmDefault, Intrawave, PipeVerV3>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmPadded, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmPadded, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmPadded, Intrawave, PipeVerV3>{});
    // clang-format on
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck