mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
CK: Extract shared boilerplate from 47 gemm_quant test files (#6323) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Depends on #6303 ## Summary Extract shared test boilerplate (includes, type aliases, test fixture macros) from 47 `test_gemm_quant_*` files into a single `test_gemm_quant_common.hpp` header. Each test file is reduced from ~50 lines of boilerplate to ~5 lines. | Metric | Value | |--------|-------| | Files changed | 48 | | Insertions | +413 | | Deletions | −1,106 | | **Net lines removed** | **−693** | ### What changed | Before | After | |--------|-------| | 47 test files, each with ~50 lines of identical includes, type aliases, and fixture macros | 1 shared header (`test_gemm_quant_common.hpp`) + 47 thin files (~5 lines each: include + params) | ### Readability assessment A code realist review confirmed this change **improves readability**: the 47 test files had identical boilerplate obscuring the only meaningful content — the `GemmConfig` type alias and test dimensions. After the refactoring, each file's unique configuration is immediately visible, and adding a new test variant requires specifying only the varying parameters instead of copying 50 lines. ### Cumulative cleanup series stats | PR | Description | Net lines | |----|-------------|-----------| | #6300 | Remove 61 dead `#if 0` blocks | −2,648 | | #6302 | Remove 41 commented-out dead code blocks | −2,861 | | #6303 | Remove 4 orphaned files | −3,886 | | This PR | Extract gemm_quant test boilerplate | −693 | | **Total** | | **−10,088** |
229 lines
30 KiB
C++
229 lines
30 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include "ck/ck.hpp"
|
|
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
|
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
|
|
|
|
using F16 = ck::half_t;
|
|
using BF16 = ck::bhalf_t;
|
|
using F32 = float;
|
|
using F64 = double;
|
|
|
|
template <ck::index_t... Is>
|
|
using S = ck::Sequence<Is...>;
|
|
|
|
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
|
|
|
// Generic instances for fp32, fp16 and bf16 data types.
|
|
template <ck::index_t NumDimM,
|
|
ck::index_t NumDimN,
|
|
ck::index_t NumDimK,
|
|
typename ADataType,
|
|
typename BDataType,
|
|
typename AccDataType,
|
|
typename CShuffleDataType,
|
|
typename DsDataType,
|
|
typename EDataType,
|
|
typename ComputeDataType,
|
|
typename AElementOp,
|
|
typename BElementOp,
|
|
typename CDEElementOp>
|
|
// clang-format off
|
|
using DeviceOpInstanceKK_Generic = ck::tensor_operation::device::
|
|
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
|
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
|
|
// clang-format on
|
|
|
|
template <ck::index_t NumDimM,
|
|
ck::index_t NumDimN,
|
|
ck::index_t NumDimK,
|
|
typename ADataType,
|
|
typename BDataType,
|
|
typename AccDataType,
|
|
typename CShuffleDataType,
|
|
typename DsDataType,
|
|
typename EDataType,
|
|
typename ComputeDataType,
|
|
typename AElementOp,
|
|
typename BElementOp,
|
|
typename CDEElementOp>
|
|
// clang-format off
|
|
using DeviceOpInstanceKN_Generic = ck::tensor_operation::device::
|
|
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
|
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
|
|
// clang-format on
|
|
|
|
template <ck::index_t NumDimM,
|
|
ck::index_t NumDimN,
|
|
ck::index_t NumDimK,
|
|
typename ADataType,
|
|
typename BDataType,
|
|
typename AccDataType,
|
|
typename CShuffleDataType,
|
|
typename DsDataType,
|
|
typename EDataType,
|
|
typename ComputeDataType,
|
|
typename AElementOp,
|
|
typename BElementOp,
|
|
typename CDEElementOp>
|
|
// clang-format off
|
|
using DeviceOpInstanceMK_Generic = ck::tensor_operation::device::
|
|
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
|
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
|
|
// clang-format on
|
|
|
|
template <ck::index_t NumDimM,
|
|
ck::index_t NumDimN,
|
|
ck::index_t NumDimK,
|
|
typename ADataType,
|
|
typename BDataType,
|
|
typename AccDataType,
|
|
typename CShuffleDataType,
|
|
typename DsDataType,
|
|
typename EDataType,
|
|
typename ComputeDataType,
|
|
typename AElementOp,
|
|
typename BElementOp,
|
|
typename CDEElementOp>
|
|
// clang-format off
|
|
using DeviceOpInstanceMN_Generic = ck::tensor_operation::device::
|
|
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
|
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
|
|
// clang-format on
|
|
|
|
// FP64 instances: the same four layout variants, with different tile sizes.
|
|
template <ck::index_t NumDimM,
|
|
ck::index_t NumDimN,
|
|
ck::index_t NumDimK,
|
|
typename ADataType,
|
|
typename BDataType,
|
|
typename AccDataType,
|
|
typename CShuffleDataType,
|
|
typename DsDataType,
|
|
typename EDataType,
|
|
typename ComputeDataType,
|
|
typename AElementOp,
|
|
typename BElementOp,
|
|
typename CDEElementOp>
|
|
// clang-format off
|
|
using DeviceOpInstanceKK_FP64 = ck::tensor_operation::device::
|
|
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
|
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
|
|
// clang-format on
|
|
|
|
template <ck::index_t NumDimM,
|
|
ck::index_t NumDimN,
|
|
ck::index_t NumDimK,
|
|
typename ADataType,
|
|
typename BDataType,
|
|
typename AccDataType,
|
|
typename CShuffleDataType,
|
|
typename DsDataType,
|
|
typename EDataType,
|
|
typename ComputeDataType,
|
|
typename AElementOp,
|
|
typename BElementOp,
|
|
typename CDEElementOp>
|
|
// clang-format off
|
|
using DeviceOpInstanceKN_FP64 = ck::tensor_operation::device::
|
|
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
|
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
|
|
// clang-format on
|
|
|
|
template <ck::index_t NumDimM,
|
|
ck::index_t NumDimN,
|
|
ck::index_t NumDimK,
|
|
typename ADataType,
|
|
typename BDataType,
|
|
typename AccDataType,
|
|
typename CShuffleDataType,
|
|
typename DsDataType,
|
|
typename EDataType,
|
|
typename ComputeDataType,
|
|
typename AElementOp,
|
|
typename BElementOp,
|
|
typename CDEElementOp>
|
|
// clang-format off
|
|
using DeviceOpInstanceMK_FP64 = ck::tensor_operation::device::
|
|
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
|
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
|
|
// clang-format on
|
|
|
|
template <ck::index_t NumDimM,
|
|
ck::index_t NumDimN,
|
|
ck::index_t NumDimK,
|
|
typename ADataType,
|
|
typename BDataType,
|
|
typename AccDataType,
|
|
typename CShuffleDataType,
|
|
typename DsDataType,
|
|
typename EDataType,
|
|
typename ComputeDataType,
|
|
typename AElementOp,
|
|
typename BElementOp,
|
|
typename CDEElementOp>
|
|
// clang-format off
|
|
using DeviceOpInstanceMN_FP64 = ck::tensor_operation::device::
|
|
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
|
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
|
|
// clang-format on
|
|
|
|
// CK_CONTRACTION_DEVICE_OP_INSTANCES(BASE, SUFFIX)
//
// Declares one alias per layout variant (KK, KN, MK, MN) by specializing the
// matching DeviceOpInstance<LAYOUT>_<BASE> alias template. It then declares
// DeviceOpInstance as the KK variant.
//
//   BASE   - tuning family: Generic (for fp16 / bf16 / fp32) or FP64 (for fp64,
//            which uses different tile sizes).
//   SUFFIX - appended to each generated alias name. By convention:
//              NN for bilinear (DsDataType = Tuple<DDataType>)
//              N  for scale    (DsDataType = Tuple<>)
//
// These names must be visible in the calling TU at the point of expansion:
//   NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType,
//   CShuffleDataType, DsDataType, EDataType, ComputeDataType,
//   AElementOp, BElementOp, CDEElementOp
//
// The expansion does not end with a semicolon, so supply one at the call site:
//   CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, NN);
// declares DeviceOpInstanceKKNN, DeviceOpInstanceKNNN, DeviceOpInstanceMKNN and
// DeviceOpInstanceMNNN, plus DeviceOpInstance = DeviceOpInstanceKKNN.
//
// Implementation detail:
// CK_CONTRACTION_DEVICE_OP_INSTANCE_(LAYOUT, BASE, SUFFIX) declares a single
// DeviceOpInstance<LAYOUT><SUFFIX> alias. That way the 13-name argument list is
// written only once.
// clang-format off
#define CK_CONTRACTION_DEVICE_OP_INSTANCE_(LAYOUT, BASE, SUFFIX)                            \
    using DeviceOpInstance##LAYOUT##SUFFIX = DeviceOpInstance##LAYOUT##_##BASE<              \
        NumDimM, NumDimN, NumDimK,                                                         \
        ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType,        \
        ComputeDataType, AElementOp, BElementOp, CDEElementOp>

#define CK_CONTRACTION_DEVICE_OP_INSTANCES(BASE, SUFFIX)    \
    CK_CONTRACTION_DEVICE_OP_INSTANCE_(KK, BASE, SUFFIX);   \
    CK_CONTRACTION_DEVICE_OP_INSTANCE_(KN, BASE, SUFFIX);   \
    CK_CONTRACTION_DEVICE_OP_INSTANCE_(MK, BASE, SUFFIX);   \
    CK_CONTRACTION_DEVICE_OP_INSTANCE_(MN, BASE, SUFFIX);   \
    using DeviceOpInstance = DeviceOpInstanceKK##SUFFIX
// clang-format on
|