mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
* wavelet gemm programming model support for CK
* GEMM pipeline update for wavelet programming model
* Updated wavelet programming pipeline
* fixes for global-write for math-wave
* fixed bug in global writes
* Updated comments for better readability
* fixed clang format errors
* added block_lds without barrier sync
* clean
* clean
* clean
* clean
* refactor
* prototype
4 layouts
fix default stride
all problem sizes
tidy
move file
update build script
restore old file
fix build
* refactor standalone test to use gemm test harness
* simplify gemm test
* update build script
* remove redundant
* early return when cmd arg doesn't match
* tidy
* report failure when result not validated
* tidy
* Add comment depicting B2C mapping pattern.
* Formatting & comments.
* Comparison with custom B2C mapping pattern.
* Example for wavelet gemm.
* Add wavelet to Gemm standalone test.
* Remove debug code.
* Remove dangling #endif directive.
Co-authored-by: root <Raman Jana>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: Adam Osewski <aosewski@amd.com>
Co-authored-by: Anthony Chang <ac.chang@outlook.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
[ROCm/composable_kernel commit: 1cfa87608a]
53 lines
6.0 KiB
C++
53 lines
6.0 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include "common.hpp"
|
|
|
|
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp"
|
|
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
|
|
|
|
// Element types for this example. They are passed as the data-type template
// arguments of the device GEMM instances and the host reference GEMM below.
using ADataType = ck::half_t; // "AData Type" column of the device instances
|
|
using BDataType = ck::half_t; // "BData Type" column of the device instances
|
|
using AccDataType = float; // "AccData Type" column of the device instances
|
|
using CShuffleDataType = float; // "CShuffle DataType" column; only DeviceGemmInstance1 takes it
|
|
using CDataType = ck::half_t; // "CData Type" column of the device instances
|
|
|
|
// Shorthand alias for ck::half_t. Not referenced in the visible part of this
// file; it may be used by run_gemm_example.inc (included below) -- TODO confirm
// before removing.
using F16 = ck::half_t;
|
|
|
|
// Layout tags passed as the ALayout/BLayout/CLayout template arguments of the
// device GEMM instances below. Row and Col are not declared in this file;
// presumably they come from common.hpp -- TODO confirm.
using ALayout = Row;
|
|
using BLayout = Col;
|
|
using CLayout = Row;
|
|
|
|
// Element-wise operation types passed in the A/B/C "Elementwise Operation"
// columns of the device instances and to the host reference GEMM. All three use
// PassThrough, which is not declared in this file (presumably provided by
// common.hpp -- TODO confirm).
using AElementOp = PassThrough;
|
|
using BElementOp = PassThrough;
|
|
using CElementOp = PassThrough;
|
|
|
|
// GEMM specialization enumerator (GemmSpecialization::Default) passed in the
// "GEMM Specialization" column of both device instances below.
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
|
|
|
// clang-format off
|
|
// Device GEMM configuration built on DeviceGemmXdl. It is defined as an
// alternative only: DeviceGemmInstance (further down) aliases DeviceGemmInstance1.
// The "######" comment rows are a column header; read top to bottom, each column
// names the template argument at the same position in the argument list.
using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
|
|
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
|
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
|
|
// clang-format on
|
|
|
|
// clang-format off
|
|
// Device GEMM configuration built on DeviceGemm_Xdl_CShuffle. This is the
// instance the example actually uses (DeviceGemmInstance below aliases it).
// Its argument order differs from DeviceGemmInstance0: the layouts come before
// the data types, and it adds CShuffleDataType, a NumGemmK prefetch-stage
// argument, separate AK1/BK1 values and the CShuffle/CBlockTransfer arguments.
// The "######" comment rows are a column header; read top to bottom, each column
// names the template argument at the same position in the argument list.
using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
|
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
|
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
|
// clang-format on
|
|
|
|
// Selects which of the two configurations above the example uses; currently the
// CShuffle variant. Presumably consumed by run_gemm_example.inc (included
// below) -- TODO confirm.
using DeviceGemmInstance = DeviceGemmInstance1;
|
|
|
|
// Host-side ReferenceGemm type instantiated with the same data types and
// element-wise operations as the device instances. Presumably used by
// run_gemm_example.inc to validate the device result -- TODO confirm.
using ReferenceGemmInstance = ck::tensor_operation::host::
|
|
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
|
|
|
#include "run_gemm_example.inc"
|
|
|
|
// Program entry point. Forwards the command line to run_gemm_example() (pulled
// in from run_gemm_example.inc) and maps its result to the process exit status:
// a true/non-zero result yields exit code 0, anything else yields exit code 1.
int main(int argc, char* argv[])
{
    if(run_gemm_example(argc, argv))
    {
        return 0;
    }

    return 1;
}
|