diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp
index 8f0631c1ce..a4567dcd6e 100644
--- a/example/01_gemm/gemm_xdl_bf16.cpp
+++ b/example/01_gemm/gemm_xdl_bf16.cpp
@@ -11,8 +11,7 @@
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
 #include "device_tensor.hpp"
-#include "device_gemm_xdl.hpp"
-#include "device_gemm_xdl_c_shuffle.hpp"
+#include "device_gemm_xdl_cshuffle.hpp"
 #include "element_wise_operation.hpp"
 #include "reference_gemm.hpp"
 #include "gemm_specialization.hpp"
@@ -37,47 +36,51 @@
 using ALayout = ck::tensor_layout::gemm::RowMajor;
 using BLayout = ck::tensor_layout::gemm::ColumnMajor;
 using CLayout = ck::tensor_layout::gemm::RowMajor;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
 // clang-format off
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle<
-    ADataType,            // ADataType
-    BDataType,            // BDataType
-    CDataType,            // CDataType
-    AccDataType,          // AccDataType
-    CDataType,            // CShuffleDataType
-    ALayout,              // ALayout
-    BLayout,              // BLayout
-    CLayout,              // CLayout
-    PassThrough,          // AElementwiseOperation
-    PassThrough,          // BElementwiseOperation
-    PassThrough,          // CElementwiseOperation
-    256,                  // BlockSize
-    256,                  // MPerBlock
-    128,                  // NPerBlock
-    32,                   // KPerBlock
-    8,                    // AK1
-    8,                    // BK1
-    32,                   // MPerXDL
-    32,                   // NPerXDL
-    4,                    // MXdlPerWave
-    2,                    // NXdlPerWave
-    S<4, 64, 1>,          // ABlockTransferThreadClusterLengths_K0_M_K1
-    S<1, 0, 2>,           // ABlockTransferThreadClusterArrangeOrder
-    S<1, 0, 2>,           // ABlockTransferSrcAccessOrder
-    2,                    // ABlockTransferSrcVectorDim
-    8,                    // ABlockTransferSrcScalarPerVector
-    8,                    // ABlockTransferDstScalarPerVector_K1
-    true,                 // ABlockLdsAddExtraM
-    S<4, 64, 1>,          // BBlockTransferThreadClusterLengths_K0_N_K1
-    S<1, 0, 2>,           // BBlockTransferThreadClusterArrangeOrder
-    S<1, 0, 2>,           // BBlockTransferSrcAccessOrder
-    2,                    // BBlockTransferSrcVectorDim
-    8,                    // BBlockTransferSrcScalarPerVector
-    8,                    // BBlockTransferDstScalarPerVector_K1
-    true,                 // BBlockLdsAddExtraN
-    1,                    // CShuffleMXdlPerWavePerShuffle
-    1,                    // CShuffleNXdlPerWavePerShuffle
-    S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
-    8>;                   // CBlockTransferScalarPerVector_NWaveNPerXdl
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
+    ALayout,              // typename ALayout
+    BLayout,              // typename BLayout
+    CLayout,              // typename CLayout
+    ADataType,            // typename ADataType
+    BDataType,            // typename BDataType
+    CDataType,            // typename CDataType
+    AccDataType,          // typename GemmAccDataType
+    CDataType,            // typename CShuffleDataType
+    PassThrough,          // typename AElementwiseOperation
+    PassThrough,          // typename BElementwiseOperation
+    PassThrough,          // typename CElementwiseOperation
+    GemmDefault,          // GemmSpecialization GemmSpec
+    1,                    // index_t NumGemmKPrefetchStage
+    256,                  // index_t BlockSize
+    256,                  // index_t MPerBlock
+    128,                  // index_t NPerBlock
+    32,                   // index_t KPerBlock
+    8,                    // index_t AK1
+    8,                    // index_t BK1
+    32,                   // index_t MPerXDL
+    32,                   // index_t NPerXDL
+    4,                    // index_t MXdlPerWave
+    2,                    // index_t NXdlPerWave
+    S<4, 64, 1>,          // typename ABlockTransferThreadClusterLengths_AK0_M_AK1
+    S<1, 0, 2>,           // typename ABlockTransferThreadClusterArrangeOrder
+    S<1, 0, 2>,           // typename ABlockTransferSrcAccessOrder
+    2,                    // index_t ABlockTransferSrcVectorDim
+    8,                    // index_t ABlockTransferSrcScalarPerVector
+    8,                    // index_t ABlockTransferDstScalarPerVector_AK1
+    1,                    // index_t ABlockLdsExtraM
+    S<4, 64, 1>,          // typename BBlockTransferThreadClusterLengths_BK0_N_BK1
+    S<1, 0, 2>,           // typename BBlockTransferThreadClusterArrangeOrder
+    S<1, 0, 2>,           // typename BBlockTransferSrcAccessOrder
+    2,                    // index_t BBlockTransferSrcVectorDim
+    8,                    // index_t BBlockTransferSrcScalarPerVector
+    8,                    // index_t BBlockTransferDstScalarPerVector_BK1
+    1,                    // index_t BBlockLdsExtraN
+    1,                    // index_t CShuffleMXdlPerWavePerShuffle
+    1,                    // index_t CShuffleNXdlPerWavePerShuffle
+    S<1, 32, 1, 8>,       // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
+    8>;                   // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
 // clang-format on
 
 using ReferenceGemmInstance = ck::tensor_operation::host::
diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp
index 2d5a95e400..fc04a13ca5
---
a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -4,7 +4,6 @@ #include #include #include - #include "check_err.hpp" #include "config.hpp" #include "device.hpp" @@ -12,7 +11,6 @@ #include "host_tensor_generator.hpp" #include "device_tensor.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_c_shuffle.hpp" #include "device_gemm_xdl_cshuffle.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" @@ -46,11 +44,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // clang-format off using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle -//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| | | | Type| Type| Type| DataType| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| -//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 
1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp index 724757565e..ab5869db61 100644 --- a/example/01_gemm/gemm_xdl_int8.cpp +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -11,8 +11,7 @@ #include "host_tensor.hpp" #include "host_tensor_generator.hpp" #include "device_tensor.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_c_shuffle.hpp" +#include "device_gemm_xdl_cshuffle.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" @@ -20,64 +19,63 @@ template using S = ck::Sequence; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = int8_t; using BDataType = int8_t; -using CDataType = int32_t; +using CDataType = int8_t; using AccDataType = int32_t; -using CShuffleDataType = int32_t; +using CShuffleDataType = int8_t; using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; using CLayout = ck::tensor_layout::gemm::RowMajor; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + // clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle< - ADataType, // ADataType - BDataType, // BDataType - CDataType, // CDataType - AccDataType, // AccDataType - CShuffleDataType, // CShuffleDataType - ALayout, // ALayout - BLayout, // BLayout - CLayout, // CLayout - PassThrough, // AElementwiseOperation - PassThrough, // BElementwiseOperation - PassThrough, // CElementwiseOperation - 256, // BlockSize - 256, // MPerBlock - 128, // NPerBlock - 64, // KPerBlock - 16, // AK1 - 16, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 16, // ABlockTransferSrcScalarPerVector - 16, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 16, // BBlockTransferSrcScalarPerVector - 16, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - 4>; // CBlockTransferScalarPerVector_NWaveNPerXdl +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle< + ALayout, // typename ALayout + BLayout, // typename BLayout + CLayout, // typename CLayout + ADataType, // typename ADataType + BDataType, // typename BDataType + CDataType, // typename CDataType + AccDataType, // typename GemmAccDataType + CShuffleDataType, // typename CShuffleDataType + PassThrough, // typename AElementwiseOperation + PassThrough, // typename BElementwiseOperation + PassThrough, // typename CElementwiseOperation + GemmDefault, // GemmSpecialization GemmSpec + 1, // index_t NumGemmKPrefetchStage + 256, // index_t BlockSize + 256, // index_t MPerBlock + 128, // 
index_t NPerBlock + 64, // index_t KPerBlock + 16, // index_t AK1 + 16, // index_t BK1 + 32, // index_t MPerXDL + 32, // index_t NPerXDL + 4, // index_t MXdlPerWave + 2, // index_t NXdlPerWave + S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder + 2, // index_t ABlockTransferSrcVectorDim + 16, // index_t ABlockTransferSrcScalarPerVector + 16, // index_t ABlockTransferDstScalarPerVector_AK1 + 1, // index_t ABlockLdsExtraM + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder + 2, // index_t BBlockTransferSrcVectorDim + 8, // index_t BBlockTransferSrcScalarPerVector + 8, // index_t BBlockTransferDstScalarPerVector_BK1 + 1, // index_t BBlockLdsExtraN + 1, // index_t CShuffleMXdlPerWavePerShuffle + 1, // index_t CShuffleNXdlPerWavePerShuffle + S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp b/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp index ca3b58bd00..324dc35d3f 100644 --- a/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp +++ b/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp @@ -13,74 +13,91 @@ #include "host_tensor_generator.hpp" #include "host_gemm.hpp" #include "device_tensor.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_c_shuffle.hpp" +#include "device_gemm_xdl_cshuffle.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" +struct RequantReluRequant +{ + // FIXME: We just need one scale for Relu / Leaky Relu / PRelu + RequantReluRequant(float scaleGemm, float scaleRelu) + : scaleGemm_(scaleGemm), scaleRelu_(scaleRelu) + { + } + + __host__ __device__ constexpr void operator()(float& y, const float& x) const + { + float gemm_requant = scaleGemm_ * x; + float relu = gemm_requant > 0 ? gemm_requant : 0; + float relu_requant = scaleRelu_ * relu; + y = relu_requant > 127 ? 127 : relu_requant < -128 ? 
-128 : relu_requant; + } + + float scaleGemm_; + float scaleRelu_; +}; + template using S = ck::Sequence; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using RequantReluRequant = ck::tensor_operation::element_wise::RequantReluRequant; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = int8_t; using BDataType = int8_t; using CDataType = int8_t; using AccDataType = int32_t; -using CShuffleDataType = int32_t; +using CShuffleDataType = float; using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; using CLayout = ck::tensor_layout::gemm::RowMajor; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + // clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle< - ADataType, // ADataType - BDataType, // BDataType - CDataType, // CDataType - AccDataType, // AccDataType - CShuffleDataType, // CShuffleDataType - ALayout, // ALayout - BLayout, // BLayout - CLayout, // CLayout - PassThrough, // AElementwiseOperation - PassThrough, // BElementwiseOperation - RequantReluRequant, // CElementwiseOperation - 256, // BlockSize - 256, // MPerBlock - 128, // NPerBlock - 64, // KPerBlock - 16, // AK1 - 16, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 16, // ABlockTransferSrcScalarPerVector - 16, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 16, // BBlockTransferSrcScalarPerVector - 16, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 1, 64, 1, 1, 4>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - 16>; // CBlockTransferScalarPerVector_NWaveNPerXdl +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle< + ALayout, // typename ALayout, + BLayout, // typename BLayout, + CLayout, // typename CLayout, + ADataType, // typename ADataType, + BDataType, // typename BDataType, + CDataType, // typename CDataType, + AccDataType, // typename GemmAccDataType, + CShuffleDataType, // typename CShuffleDataType, + PassThrough, // typename AElementwiseOperation, + PassThrough, // typename BElementwiseOperation, + RequantReluRequant, // typename CElementwiseOperation, + GemmDefault, // GemmSpecialization GemmSpec, + 1, // index_t NumGemmKPrefetchStage, + 256, // index_t BlockSize, + 256, // index_t MPerBlock, + 128, // index_t NPerBlock, + 64, // index_t KPerBlock, + 16, // index_t AK1, + 16, // index_t BK1, + 32, // index_t MPerXDL, + 32, // index_t NPerXDL, + 4, // index_t MXdlPerWave, + 2, // index_t NXdlPerWave, + S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder, + 2, // index_t ABlockTransferSrcVectorDim, + 16, // index_t 
ABlockTransferSrcScalarPerVector, + 16, // index_t ABlockTransferDstScalarPerVector_AK1, + 1, // bool ABlockLdsExtraM, + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder, + 2, // index_t BBlockTransferSrcVectorDim, + 8, // index_t BBlockTransferSrcScalarPerVector, + 8, // index_t BBlockTransferDstScalarPerVector_BK1, + 1, // bool BBlockLdsExtraN, + 1, // index_t CShuffleMXdlPerWavePerShuffle, + 1, // index_t CShuffleNXdlPerWavePerShuffle, + S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock> // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/include/ck/config.hpp b/include/ck/config.hpp index eedeb7e136..e6deefcbe3 100644 --- a/include/ck/config.hpp +++ b/include/ck/config.hpp @@ -26,17 +26,14 @@ #endif #endif -// buffer resourse, wave size +// buffer resource #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_BUFFER_RESOURCE_3RD_DWORD -1 -#define CK_GPU_WAVE_SIZE -1 #elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ defined(__gfx90a__) // for GPU code #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 -#define CK_GPU_WAVE_SIZE 64 #elif defined(__gfx1030__) // for GPU code #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 -#define CK_GPU_WAVE_SIZE 32 #endif // FMA instruction diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 8fe4beecba..f1670d9c89 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -1,10 +1,9 @@ -#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP -#define CK_BLOCKWISE_GEMM_XDLOPS_HPP - +#pragma once #include "common_header.hpp" #include "threadwise_tensor_slice_transfer.hpp" #include "xdlops_gemm.hpp" #include "tensor_adaptor.hpp" +#include "thread_group.hpp" namespace ck { @@ -25,7 +24,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - static constexpr index_t WaveSize = 64; + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); @@ -55,7 +56,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 __device__ static auto GetWaveIdx() { - const index_t thread_id = get_thread_local_1d_id(); + const index_t thread_id = ThisThreadBlock::GetThreadId(); constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), @@ -122,8 +123,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 BK0NK1BlockDesc::IsKnownAtCompileTime(), "wrong! 
Desc should be known at compile-time"); - static_assert(BlockSize == MWaves * NWaves * WaveSize, - "BlockSize != MWaves * NWaves * WaveSize\n"); + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, "wrong!"); @@ -339,4 +340,3 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp index acd99132cc..93fe5da723 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp @@ -45,8 +45,8 @@ struct BlockwiseTensorSliceTransfer_v5r1 src_desc, make_zero_multi_index(), dst_desc, make_zero_multi_index()) { - static_assert(nDim == remove_reference_t>::GetNumOfDimension() && - nDim == remove_reference_t>::GetNumOfDimension() && + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() && nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterArrangeOrder::Size() && diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp similarity index 82% rename from include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v4r1.hpp rename to include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp index 5aa6600848..cbabbaf47d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v4r1.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp @@ -1,6 +1,4 @@ -#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP -#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP - +#pragma once #include "common_header.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" @@ -13,7 +11,7 @@ namespace ck { // 1. Use StaticallyIndexedArray instead of C array for thread buffer // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate -template -struct BlockwiseTensorSliceTransfer_v4r1 +struct ThreadGroupTensorSliceTransfer_v4r1 { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); @@ -43,7 +41,7 @@ struct BlockwiseTensorSliceTransfer_v4r1 using Index = MultiIndex; - __device__ constexpr BlockwiseTensorSliceTransfer_v4r1( + __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1( const SrcDesc& src_desc, const Index& src_block_slice_origin, const SrcElementwiseOperation& src_element_op, @@ -58,8 +56,8 @@ struct BlockwiseTensorSliceTransfer_v4r1 dst_element_op) { - static_assert(nDim == remove_reference_t>::GetNumOfDimension() && - nDim == remove_reference_t>::GetNumOfDimension() && + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterArrangeOrder::Size() && nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), @@ -69,14 +67,14 @@ struct BlockwiseTensorSliceTransfer_v4r1 is_same{}, "wrong! 
threads should be mapped to cover entire slicing window"); - static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), - "wrong! BlockSize too small"); + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( - make_multi_index(get_thread_local_1d_id())); + make_multi_index(ThreadGroup::GetThreadId())); const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; @@ -92,8 +90,8 @@ struct BlockwiseTensorSliceTransfer_v4r1 const SrcBuffer& src_buf, Number thread_scratch_id = Number{}) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id); } @@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v4r1 DstBuffer& dst_buf, Number thread_scratch_id = Number{}) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id); } @@ -124,8 +122,8 @@ struct BlockwiseTensorSliceTransfer_v4r1 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); } @@ -133,8 +131,8 @@ struct BlockwiseTensorSliceTransfer_v4r1 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); } @@ -169,4 +167,3 @@ struct BlockwiseTensorSliceTransfer_v4r1 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp similarity index 68% rename from include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r1.hpp rename to include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp index 957c8f522c..1f0ad3e35a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r1.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp @@ -1,6 +1,4 @@ -#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP -#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP - +#pragma once #include 
"common_header.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" @@ -13,10 +11,10 @@ namespace ck { // 1. Use StaticallyIndexedArray instead of C array for thread buffer // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate -template -struct BlockwiseTensorSliceTransfer_v6r1 +struct ThreadGroupTensorSliceTransfer_v6r1 { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); - static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; using Index = MultiIndex; - __device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc, - const Index& src_block_slice_origin, - const DstDesc& dst_desc, - const Index& dst_block_slice_origin, - const ElementwiseOperation& element_op) + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1(const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) : threadwise_transfer_(src_desc, make_zero_multi_index(), dst_desc, @@ -48,25 +46,25 @@ struct BlockwiseTensorSliceTransfer_v6r1 element_op) { - static_assert(nDim == remove_reference_t>::GetNumOfDimension() && - nDim == remove_reference_t>::GetNumOfDimension() && + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterArrangeOrder::Size() && nDim == DimAccessOrder::Size(), "wrong! nDim not consistent"); static_assert( - is_same{}, + is_same{}, "wrong! threads should be mapped to cover entire slicing window"); - static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), - "wrong! BlockSize too small"); + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( - make_multi_index(get_thread_local_1d_id())); + make_multi_index(ThreadGroup::GetThreadId())); const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; @@ -83,8 +81,8 @@ struct BlockwiseTensorSliceTransfer_v6r1 const DstDesc& dst_desc, DstBuffer& dst_buf) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf); } @@ -92,8 +90,8 @@ struct BlockwiseTensorSliceTransfer_v6r1 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); } @@ -101,8 +99,8 @@ struct BlockwiseTensorSliceTransfer_v6r1 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); } @@ -130,4 +128,3 @@ struct BlockwiseTensorSliceTransfer_v6r1 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp similarity index 68% rename from include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r2.hpp rename to include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp index 2e06214b8c..121ddf12ad 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r2.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp @@ -1,6 +1,4 @@ -#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP -#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP - +#pragma once #include "common_header.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" @@ -13,10 +11,10 @@ namespace ck { // 1. Use StaticallyIndexedArray instead of C array for thread buffer // 2. It does not keep reference to tensor descriptor // 3. 
Run() does not construct new tensor coordinate -template -struct BlockwiseTensorSliceTransfer_v6r2 +struct ThreadGroupTensorSliceTransfer_v6r2 { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); - static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; using Index = MultiIndex; - __device__ constexpr BlockwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, - const Index& src0_block_slice_origin, - const Src1Desc& src1_desc, - const Index& src1_block_slice_origin, - const DstDesc& dst_desc, - const Index& dst_block_slice_origin, - const ElementwiseOperation& element_op) + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, + const Index& src0_block_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) : threadwise_transfer_(src0_desc, make_zero_multi_index(), src1_desc, @@ -55,26 +53,26 @@ struct BlockwiseTensorSliceTransfer_v6r2 element_op) { - static_assert(nDim == remove_reference_t>::GetNumOfDimension() && - nDim == remove_reference_t>::GetNumOfDimension() && - nDim == remove_reference_t>::GetNumOfDimension() && + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterArrangeOrder::Size() && nDim == DimAccessOrder::Size(), "wrong! nDim not consistent"); static_assert( - is_same{}, + is_same{}, "wrong! threads should be mapped to cover entire slicing window"); - static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), - "wrong! BlockSize too small"); + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( - make_multi_index(get_thread_local_1d_id())); + make_multi_index(ThreadGroup::GetThreadId())); const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; @@ -95,8 +93,8 @@ struct BlockwiseTensorSliceTransfer_v6r2 const DstDesc& dst_desc, DstBuffer& dst_buf) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf); } @@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v6r2 __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step); } @@ -113,8 +111,8 @@ struct BlockwiseTensorSliceTransfer_v6r2 __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step); } @@ -122,8 +120,8 @@ struct BlockwiseTensorSliceTransfer_v6r2 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); } @@ -154,4 +152,3 @@ struct BlockwiseTensorSliceTransfer_v6r2 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp similarity index 68% rename from include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r3.hpp rename to include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp index 085981736b..ca5db90f30 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r3.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp @@ -1,6 +1,4 @@ -#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP -#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP - +#pragma once #include "common_header.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" @@ -13,10 +11,10 @@ namespace ck { // 1. Use StaticallyIndexedArray instead of C array for thread buffer // 2. 
ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate -template -struct BlockwiseTensorSliceTransfer_v6r3 +struct ThreadGroupTensorSliceTransfer_v6r3 { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); - static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; using Index = MultiIndex; - __device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, - const Index& src0_block_slice_origin, - const Src1Desc& src1_desc, - const Index& src1_block_slice_origin, - const Src2Desc& src2_desc, - const Index& src2_block_slice_origin, - const DstDesc& dst_desc, - const Index& dst_block_slice_origin, - const ElementwiseOperation& element_op) + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, + const Index& src0_block_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_block_slice_origin, + const Src2Desc& src2_desc, + const Index& src2_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) : threadwise_transfer_(src0_desc, make_zero_multi_index(), src1_desc, @@ -62,24 +60,24 @@ struct BlockwiseTensorSliceTransfer_v6r3 element_op) { - static_assert(nDim == remove_reference_t>::GetNumOfDimension() && - nDim == remove_reference_t>::GetNumOfDimension() && - nDim == remove_reference_t>::GetNumOfDimension() && - nDim == remove_reference_t>::GetNumOfDimension() && + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterArrangeOrder::Size() && nDim == DimAccessOrder::Size(), "wrong! nDim not consistent"); static_assert( - is_same{}, + is_same{}, "wrong! threads should be mapped to cover entire slicing window"); - static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), - "wrong! BlockSize too small"); + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( make_multi_index(get_thread_local_1d_id())); @@ -107,8 +105,8 @@ struct BlockwiseTensorSliceTransfer_v6r3 const DstDesc& dst_desc, DstBuffer& dst_buf) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.Run( src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf); @@ -117,8 +115,8 @@ struct BlockwiseTensorSliceTransfer_v6r3 __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step); } @@ -126,8 +124,8 @@ struct BlockwiseTensorSliceTransfer_v6r3 __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step); } @@ -135,8 +133,8 @@ struct BlockwiseTensorSliceTransfer_v6r3 __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step); } @@ -144,8 +142,8 @@ struct BlockwiseTensorSliceTransfer_v6r3 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); } @@ -179,4 +177,3 @@ struct BlockwiseTensorSliceTransfer_v6r3 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp index 46b3939142..a90bc44fdf 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -720,11 +720,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce #include #include "device.hpp" @@ -660,13 +658,12 @@ struct 
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); float ave_time = 0; - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v3r2< GridwiseGemm, @@ -919,4 +916,3 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 7f666b32ea..b508606a75 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -640,13 +640,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); float ave_time = 0; - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v3r1< GridwiseGemm, diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index f334cb9c8d..3574f7667e 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -478,13 +478,12 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); float ave_time = 0; - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v2r3< GridwiseGemm, diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp index 9182b0ef1f..ff267c6cdf 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -1296,11 +1296,10 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]); - const auto K0 = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0); + const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * + arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); - const bool 
has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); - - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v2r3< GridwiseGemm, diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp index b13466274f..ac62448386 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -775,13 +775,12 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); float ave_time = 0; - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v2r3< GridwiseGemm, diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index f6856c65c4..1a3fbdf956 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -528,11 +528,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce -#include -#include "device.hpp" -#include "device_gemm.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v3r1.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template < - typename ADataType, - typename BDataType, - typename CDataType, - typename AccDataType, - typename CShuffleDataType, - typename ALayout, - typename BLayout, - typename CLayout, - typename AElementwiseOperation, - typename BElementwiseOperation, - typename CElementwiseOperation, - ck::index_t BlockSize, - ck::index_t MPerBlock, - ck::index_t NPerBlock, - ck::index_t KPerBlock, - ck::index_t AK1, - ck::index_t BK1, - ck::index_t MPerXDL, - ck::index_t NPerXDL, - ck::index_t MXdlPerWave, - ck::index_t NXdlPerWave, - typename ABlockTransferThreadClusterLengths_K0_M_K1, - typename ABlockTransferThreadClusterArrangeOrder, - typename ABlockTransferSrcAccessOrder, - ck::index_t ABlockTransferSrcVectorDim, - ck::index_t ABlockTransferSrcScalarPerVector, - ck::index_t ABlockTransferDstScalarPerVector_K1, - bool ABlockLdsAddExtraM, - typename BBlockTransferThreadClusterLengths_K0_N_K1, - typename BBlockTransferThreadClusterArrangeOrder, - typename BBlockTransferSrcAccessOrder, - ck::index_t BBlockTransferSrcVectorDim, - ck::index_t BBlockTransferSrcScalarPerVector, - ck::index_t BBlockTransferDstScalarPerVector_K1, - bool BBlockLdsAddExtraN, - index_t CShuffleMXdlPerWavePerShuffle, - index_t CShuffleNXdlPerWavePerShuffle, - typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl, - index_t NumPrefetch = 1> -struct DeviceGemmXdl_C_Shuffle - : public DeviceGemm -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - - 
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) - { - assert(K % AK1 == 0); - - const index_t K0 = K / AK1; - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - const auto a_grid_desc_k0_m_k1 = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, AK1)), make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_k0_m_k1; - } - - static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) - { - assert(K % BK1 == 0); - - const index_t K0 = K / BK1; - - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); - } - }(); - - const auto b_grid_desc_k0_n_k1 = transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, BK1)), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_k0_n_k1; - } - - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) - { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - } - - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); - using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - - // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< - BlockSize, - ADataType, // TODO: distinguish A/B datatype - AccDataType, - CShuffleDataType, - CDataType, - InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, - ABlockLdsAddExtraM, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, - BBlockLdsAddExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - CBlockTransferScalarPerVector_NWaveNPerXdl, - NumPrefetch>; - - // Argument - struct Argument : public BaseArgument - { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t M01, - index_t N01, - AElementwiseOperation a_element_op, - 
BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, - c_grid_desc_m_n_{}, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op} - { - a_grid_desc_k0_m_k1_ = - DeviceGemmXdl_C_Shuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); - b_grid_desc_k0_n_k1_ = - DeviceGemmXdl_C_Shuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); - c_grid_desc_m_n_ = DeviceGemmXdl_C_Shuffle::MakeCGridDescriptor_M_N(M, N, StrideC); - - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) - { - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - } - } - - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceGemmXdl_C_Shuffle::Argument; - - float Run(const Argument& arg, int nrepeat = 1) - { - { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - } - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) - { - throw std::runtime_error( - "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); - } - - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); - - float ave_time = 0; - - if(has_main_k0_block_loop) - { - const auto kernel = kernel_gemm_xdlops_v3r1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } - else - { - const auto kernel = kernel_gemm_xdlops_v3r1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override - { - return Run(*dynamic_cast(p_arg), nrepeat); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(const ADataType* p_a, - const BDataType* p_b, - CDataType* p_c, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - { - return Argument{p_a, - p_b, - p_c, - M, - N, - K, - StrideA, - StrideB, - StrideC, - 1, - 1, - a_element_op, - b_element_op, - c_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - // polymorphic - std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - index_t /* KBatch */ = 1) override - { - return std::make_unique(static_cast(p_a), - static_cast(p_b), - static_cast(p_c), - M, - N, - K, - 
- - // polymorphic - std::unique_ptr<BaseInvoker> MakeInvokerPointer() override - { - return std::make_unique<Invoker>(Invoker{}); - } - - // polymorphic - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGemmXdl_C_Shuffle" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << KPerBlock << ", " - << AK1 << ", " - << BK1 - << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp index 9cdb8009fb..4010965312 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp @@ -1,6 +1,4 @@ -#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_2D_HPP -#define DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_2D_HPP - +#pragma once #include <iostream> #include <sstream> #include "device.hpp" @@ -291,18 +289,17 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d arg.N01_)) { throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 has invalid setting"); } const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); float ave_time = 0; - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v3r2< GridwiseGemm, @@ -505,4 +502,3 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp index cf9804ad4b..c65ff6022a 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp @@ -303,13 +303,12 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); float ave_time = 0; - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v3r2< GridwiseGemm, diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp index 12257859c7..4a478c995d 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp @@ -345,13 +345,12 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); float ave_time = 0; - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v3r3< GridwiseGemm, diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp index 324b33ffb2..440519537e 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp @@ -462,13 +462,12 @@ struct DeviceGemm_Xdl_CShuffle const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); float ave_time = 0; - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdl_cshuffle_v1< GridwiseGemm, diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp index bebe2fd61e..b9ad39578d 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp @@ -17,6 +17,88 @@ namespace ck { namespace tensor_operation { namespace device { +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdlops_v2r3( + const StaticallyIndexedArray gemm_descs, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + +#if 1 + static_for<0, MaxGroupCount, 1>{}([&](auto i) { + if(block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_ && + i < group_count) + { + auto group_id = i; + + GridwiseGemm::template Run( + gemm_descs[group_id].a_ptr, + gemm_descs[group_id].b_ptr, + gemm_descs[group_id].c_ptr, + p_shared, + gemm_descs[group_id].a_grid_desc_k0_m_k1_, + gemm_descs[group_id].b_grid_desc_k0_n_k1_, + gemm_descs[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + a_element_op, + b_element_op, + c_element_op, + gemm_descs[group_id].grouped_gemm_block_2_ctile_map_); + } + }); +#else + const auto gemm_desc_ptr = reinterpret_cast(&gemm_descs); + + index_t group_id = 0; + static_for<0, MaxGroupCount, 1>{}([&](auto i) { + group_id = (block_id >= gemm_descs[i].BlockStart && block_id < gemm_descs[i].BlockEnd && + i < group_count) + ? 
i + : group_id; + }); + + const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart; + + GridwiseGemm::template Run( + gemm_desc_ptr[group_id].a_ptr, + gemm_desc_ptr[group_id].b_ptr, + gemm_desc_ptr[group_id].c_ptr, + p_shared, + gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_, + gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_, + gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + a_element_op, + b_element_op, + c_element_op, + gemm_desc_ptr[group_id].block_2_ctile_map_, + block_id_grp); +#endif +#else + ignore = gemm_descs; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + template gemm_desc_kernel_arg_arg; + StaticallyIndexedArray gemm_desc_kernel_args; - bool has_main_k0_block_loop = true; + bool has_main_k_block_loop = true; static_for<0, MaxGroupCount, 1>{}([&](auto i) { if(i < arg.gemm_desc_kernel_arg_.size()) { - gemm_desc_kernel_arg_arg(i) = arg.gemm_desc_kernel_arg_[i]; + gemm_desc_kernel_args(i) = arg.gemm_desc_kernel_arg_[i]; std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{" - << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " - << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I1) - << ", " - << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I2) - << "}"; + << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"; std::cout << ", arg.b_grid_desc_k0_n_k1_{" - << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " - << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I1) - << ", " - << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I2) - << "}"; + << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"; std::cout << ", arg.c_grid_desc_m_n_{ " - << gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_.GetLength(I0) << ", " - << gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_.GetLength(I1) << "}" + << gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I0) << ", " + << gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - if(!GridwiseGemm::CheckValidity( - gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_, - gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_, - gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + if(!GridwiseGemm::CheckValidity(gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_, + gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_, + gemm_desc_kernel_args[i].c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) { throw std::runtime_error( "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); } - const auto K0 = gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0); + const auto K = gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) * + gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2); - if(GridwiseGemm::CalculateHasMainK0BlockLoop(K0) != has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop) { - throw std::runtime_error("wrong! not all gemm has_main_k0_block_loop"); + throw std::runtime_error("wrong! 
not all gemm has_main_k_block_loop"); } } }); float ave_time = 0; - if(has_main_k0_block_loop) + if(has_main_k_block_loop) { const auto kernel = kernel_grouped_gemm_xdlops_v2r3(x); - float relu = gemm_requant > 0 ? gemm_requant : 0; - float relu_requant = scaleRelu_ * relu; - y = static_cast(relu_requant > 127 ? 127 - : relu_requant < -128 ? -128 : relu_requant); - } - - // for reference_gemm - __host__ __device__ constexpr void operator()(float& y, const float& x) const - { - float gemm_requant = scaleGemm_ * x; - float relu = gemm_requant > 0 ? gemm_requant : 0; - float relu_requant = scaleRelu_ * relu; - y = static_cast(relu_requant > 127 ? 127 - : relu_requant < -128 ? -128 : relu_requant); - } - - float scaleGemm_; - float scaleRelu_; -}; - // Unary operators are usually called element-wisely before/after the reduction is executed on the // elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2 diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index dcacd99ae1..6a1b6eef31 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -1,65 +1,41 @@ -#ifndef CK_GRIDWISE_GEMM_PIPELINE_V1_HPP -#define CK_GRIDWISE_GEMM_PIPELINE_V1_HPP - +#pragma once #include "common_header.hpp" namespace ck { -template +template struct GridwiseGemmPipeline_v1; // 1-stage prefetch -template -struct GridwiseGemmPipeline_v1 +template <> +struct GridwiseGemmPipeline_v1<1> { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - static __device__ void Run(const AGridDesc& a_grid_desc, + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, const AGridBuffer& a_grid_buf, @@ -75,51 +51,6 @@ struct GridwiseGemmPipeline_v1 -struct GridwiseGemmPipeline_v1 +template <> +struct GridwiseGemmPipeline_v1<2> { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; + __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + { + // TODO: improve applicability + return num_loop % 2 == 0; + } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return (num_loop / 2) > 1; + } + + template static __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -322,4 +249,3 @@ struct GridwiseGemmPipeline_v1 + bool HasMainKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -50,21 +50,21 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_d0_grid, - p_d1_grid, - p_shared, - a_element_op, - b_element_op, - c_element_op, - d1_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - d_grid_desc_mblock_mperblock, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_d0_grid, + p_d1_grid, + p_shared, + a_element_op, + 
b_element_op, + c_element_op, + d1_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + d_grid_desc_mblock_mperblock, + block_2_ctile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -152,6 +152,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 static constexpr auto AK1 = Number{}; static constexpr auto BK1 = Number{}; + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { // A matrix in LDS memory, dst of blockwise copy @@ -235,21 +239,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) return false; - // check NumGemmKPrefetchStage - if constexpr(NumGemmKPrefetchStage == 1) - { - // 1-stage prefetch always supported - } - else if constexpr(NumGemmKPrefetchStage == 2) - { - // 2-stage prefetch currently only support even number of K0 loop - // TODO: add support for odd number of K0 loop - if(!((K / KPerBlock) % 2 == 0)) - { - return false; - } - } - else + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; } @@ -269,12 +262,11 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 return grid_size; } - // TODO move this function into GEMM-pipeline class - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const bool has_main_k0_block_loop = ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1; + const index_t num_loop = K / KPerBlock; - return has_main_k0_block_loop; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } __host__ __device__ static constexpr auto @@ -360,7 +352,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using DefaultBlock2CTileMap = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, @@ -411,28 +403,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - NumGemmKPrefetchStage>( + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( a_grid_desc_ak0_m_ak1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -442,28 +434,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - 
BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - NumGemmKPrefetchStage>( + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( b_grid_desc_bk0_n_bk1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -510,43 +502,25 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); // gridwise GEMM pipeline - const auto gridwise_gemm_pipeline = - GridwiseGemmPipeline_v1, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - NumGemmKPrefetchStage, - HasMainK0BlockLoop>{}; - const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - blockwise_gemm, - c_thread_buf, - num_k_block_main_loop); + GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); // shuffle C and write out { @@ -662,8 +636,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ck::tensor_operation::element_wise::PassThrough{}}; // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1< - BlockSize, // index_t BlockSize, + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup CElementwiseOperation, // ElementwiseOperation, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 3354831e35..b28907b43e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -4,8 +4,8 @@ #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer_v4r1.hpp" -#include "blockwise_tensor_slice_transfer_v6r1.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" 
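The `gridwise_gemm_pipeline_v1.hpp` rework shown earlier in this patch is what the gridwise GEMMs below consume through the new `GridwiseGemmPipe` alias. Reduced to its skeleton, the pattern keys the class template on the prefetch-stage count alone and pushes everything else onto `Run`; the sketch below mirrors the shape and the loop policies visible in the diff, with bodies elided (names under `sketch` are illustrative, not CK's):

```cpp
// Skeleton of the prefetch-stage specialization pattern.
#include "common_header.hpp"

namespace sketch {

template <ck::index_t NumPrefetchStages>
struct GemmPipeline;

template <>
struct GemmPipeline<1>
{
    // 1-stage prefetch places no constraint on the K-loop count.
    __host__ __device__ static constexpr bool IsSupported(ck::index_t) { return true; }

    // With one tile in flight, a steady-state main loop exists iff num_loop > 1.
    __host__ __device__ static constexpr bool CalculateHasMainLoop(ck::index_t num_loop)
    {
        return num_loop > 1;
    }

    template <bool HasMainLoop, typename ACopy, typename BCopy, typename BlockGemm>
    __device__ static void Run(ACopy&, BCopy&, BlockGemm&, ck::index_t num_loop);
};

template <>
struct GemmPipeline<2>
{
    // 2-stage prefetch currently requires an even K-loop count.
    __host__ __device__ static constexpr bool IsSupported(ck::index_t num_loop)
    {
        return num_loop % 2 == 0;
    }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(ck::index_t num_loop)
    {
        return (num_loop / 2) > 1;
    }

    template <bool HasMainLoop, typename ACopy, typename BCopy, typename BlockGemm>
    __device__ static void Run(ACopy&, BCopy&, BlockGemm&, ck::index_t num_loop);
};

} // namespace sketch
```

Moving the descriptor and transfer types off the class template and onto `Run` is what lets each gridwise GEMM drop its fourteen-argument `GridwiseGemmPipeline_v1<...>` instantiation and simply call `GridwiseGemmPipe::template Run<HasMainKBlockLoop>(...)`, while `IsSupported` centralizes the even-loop-count restriction that was previously duplicated in every `CheckValidity`.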
+#include "thread_group_tensor_slice_transfer_v6r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" #include "gridwise_gemm_pipeline_v1.hpp" @@ -21,7 +21,7 @@ template + bool HasMainKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -41,17 +41,17 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_element_op, - b_element_op, - c_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -125,6 +125,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 static constexpr auto AK1 = Number{}; static constexpr auto BK1 = Number{}; + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { // A matrix in LDS memory, dst of blockwise copy @@ -190,10 +194,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const CGridDesc_M_N& c_grid_desc_m_n) { - // static_assert(is_known_at_compile_time>::value && - // is_known_at_compile_time>::value, - // "wrong! K1 need to be known at compile-time"); - static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); @@ -208,21 +208,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) return false; - // check NumGemmKPrefetchStage - if constexpr(NumGemmKPrefetchStage == 1) - { - // 1-stage prefetch always supported - } - else if constexpr(NumGemmKPrefetchStage == 2) - { - // 2-stage prefetch currently only support even number of K0 loop - // TODO: add support for odd number of K0 loop - if(!((K / KPerBlock) % 2 == 0)) - { - return false; - } - } - else + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; } @@ -242,12 +231,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 return grid_size; } - // TODO move this function into GEMM-pipeline class - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const bool has_main_k0_block_loop = ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1; + const index_t num_loop = K / KPerBlock; - return has_main_k0_block_loop; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } __host__ __device__ static constexpr auto @@ -315,7 +303,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using DefaultBlock2CTileMap = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, @@ -358,28 +346,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - 
ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - NumGemmKPrefetchStage>( + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( a_grid_desc_ak0_m_ak1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -389,28 +377,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - NumGemmKPrefetchStage>( + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( b_grid_desc_bk0_n_bk1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -457,43 +445,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); // gridwise GEMM pipeline - const auto gridwise_gemm_pipeline = - GridwiseGemmPipeline_v1, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - NumGemmKPrefetchStage, - HasMainK0BlockLoop>{}; - const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - blockwise_gemm, - c_thread_buf, - num_k_block_main_loop); + GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); // shuffle C and 
write out { @@ -609,8 +579,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ck::tensor_operation::element_wise::PassThrough{}}; // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1< - BlockSize, // index_t BlockSize, + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup CElementwiseOperation, // ElementwiseOperation, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index ae935593fe..19a37d4878 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -1,12 +1,10 @@ -#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP -#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP - +#pragma once #include "common_header.hpp" #include "multi_index_transform_helper.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" #include "gridwise_gemm_pipeline_v1.hpp" @@ -22,7 +20,7 @@ template + bool HasMainKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -42,17 +40,17 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -67,88 +65,6 @@ __global__ void #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_grouped_gemm_xdlops_v2r3( - const StaticallyIndexedArray gemm_desc_, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - const index_t block_id = get_block_1d_id(); - -#if 1 - static_for<0, MaxGroupCount, 1>{}([&](auto i) { - if(block_id >= gemm_desc_[i].BlockStart_ && block_id < gemm_desc_[i].BlockEnd_ && - i < group_count) - { - auto group_id = i; - - GridwiseGemm::template Run( - gemm_desc_[group_id].a_ptr, - gemm_desc_[group_id].b_ptr, - gemm_desc_[group_id].c_ptr, - p_shared, - gemm_desc_[group_id].a_grid_desc_k0_m_k1_, - gemm_desc_[group_id].b_grid_desc_k0_n_k1_, - gemm_desc_[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - a_element_op, - b_element_op, - c_element_op, - gemm_desc_[group_id].grouped_gemm_block_2_ctile_map_); - } - }); -#else - const auto gemm_desc_ptr = reinterpret_cast(&gemm_desc_); - - index_t group_id = 0; 
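The kernel deleted from this grid-level header is the one re-added near the top of this patch in `device_grouped_gemm_xdl.hpp`. Its core trick is mapping a flat 1-D block id onto a group via per-group half-open block ranges, either with an early-out branch (the `#if 1` path) or with the branchless select shown in the `#else` path. A host-side model of that lookup, runnable as-is (`GroupRange` and `find_group` are illustrative names, not CK types):

```cpp
#include <array>
#include <cstdio>

// Each group owns the half-open block range [BlockStart, BlockEnd);
// group_count may be smaller than the compile-time MaxGroupCount.
struct GroupRange { int BlockStart; int BlockEnd; };

template <int MaxGroupCount>
int find_group(const std::array<GroupRange, MaxGroupCount>& ranges,
               int group_count, int block_id)
{
    int group_id = 0;
    // The device code unrolls this loop with static_for over MaxGroupCount;
    // the branchless variant keeps the last matching index, exactly as in
    // the #else path of the kernel.
    for(int i = 0; i < MaxGroupCount; ++i)
    {
        group_id = (block_id >= ranges[i].BlockStart &&
                    block_id < ranges[i].BlockEnd && i < group_count)
                       ? i
                       : group_id;
    }
    return group_id;
}

int main()
{
    // Three active groups tiling a 24-block grid: 10 + 6 + 8 blocks.
    std::array<GroupRange, 4> ranges{{{0, 10}, {10, 16}, {16, 24}, {0, 0}}};
    std::printf("block 12 -> group %d\n", find_group<4>(ranges, 3, 12)); // group 1
}
```

Because `MaxGroupCount` is a compile-time bound and the descriptors travel in a `StaticallyIndexedArray` kernel argument, one kernel launch covers all groups without a device-side indirection table.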
- static_for<0, MaxGroupCount, 1>{}([&](auto i) { - group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd && - i < group_count) - ? i - : group_id; - }); - - const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart; - - GridwiseGemm::template Run( - gemm_desc_ptr[group_id].a_ptr, - gemm_desc_ptr[group_id].b_ptr, - gemm_desc_ptr[group_id].c_ptr, - p_shared, - gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_, - gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_, - gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - a_element_op, - b_element_op, - c_element_op, - gemm_desc_ptr[group_id].block_2_ctile_map_, - block_id_grp); -#endif -#else - ignore = gemm_desc_; - ignore = group_count; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) -} - template + index_t NumGemmKPrefetchStage = 1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 { static constexpr auto I0 = Number<0>{}; @@ -202,6 +118,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // K1 should be Number<...> static constexpr auto K1 = Number{}; + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { constexpr auto max_lds_align = K1; @@ -291,21 +211,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; - // check NumPrefetch - if constexpr(NumPrefetch == 1) - { - // 1-stage prefetch always supported - } - else if constexpr(NumPrefetch == 2) - { - // 2-stage prefetch currently only support even number of K0 loop - // TODO: add support for odd number of K0 loop - if(!((K0 / K0PerBlock) % 2 == 0)) - { - return false; - } - } - else + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; } @@ -335,12 +244,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return grid_size; } - // TODO move this function into GEMM-pipeline class - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1; + const index_t num_loop = K / (K0PerBlock * K1); - return has_main_k0_block_loop; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } __host__ __device__ static constexpr auto @@ -433,7 +341,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -478,28 +386,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_k0_m_k1), - decltype(a_block_desc_k0_m_k1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - NumPrefetch>( + 
ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( a_grid_desc_k0_m_k1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -509,28 +417,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_n_k1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - NumPrefetch>( + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( b_grid_desc_k0_n_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -575,41 +483,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); // gridwise GEMM pipeline - const auto gridwise_gemm_pipeline = - GridwiseGemmPipeline_v1, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - NumPrefetch, - HasMainK0BlockLoop>{}; + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - - gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1, - a_block_desc_k0_m_k1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_k0_n_k1, - b_block_desc_k0_n_k1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - blockwise_gemm, - c_thread_buf, - K0BlockMainLoop); + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); // output: register to global memory { @@ -692,4 +582,3 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index e9162f6e8a..4cc9345308 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -6,7 +6,7 @@ #include "tensor_descriptor.hpp" 
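A quick numeric check of the predicate rewrite that recurs across these files: callers now reconstruct the full reduction length as `K = GetLength(I0) * GetLength(I2)` (that is, K0 * K1), and the gridwise op derives the loop count instead of hard-coding the prefetch factor. The sizes below are illustrative, not taken from a particular instance in this patch:

```cpp
// Worked example of CalculateHasMainKBlockLoop after the refactor.
int main()
{
    constexpr int K0PerBlock = 4;
    constexpr int K1         = 8;       // KPerBlock = K0PerBlock * K1 = 32
    constexpr int K0         = 16;      // a_grid_desc.GetLength(I0)
    constexpr int K          = K0 * K1; // 128, reconstructed as I0 * I2

    // New form: the caller passes K; the gridwise op computes the loop count.
    constexpr int num_loop = K / (K0PerBlock * K1); // 128 / 32 = 4

    // The old form computed the same count directly from K0.
    static_assert(num_loop == K0 / K0PerBlock, "both formulations agree");

    // Pipeline policies from gridwise_gemm_pipeline_v1.hpp:
    static_assert(num_loop > 1, "1-stage prefetch has a main loop");
    static_assert(num_loop % 2 == 0, "2-stage prefetch is supported");
    static_assert((num_loop / 2) > 1, "2-stage prefetch has a main loop");
}
```

The practical difference is that the prefetch-stage count no longer leaks into the device-op invokers: they ask the gridwise op one question about K, and the pipeline specialization owns the divisor.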
#include "tensor_descriptor_helper.hpp" #include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" namespace ck { @@ -120,6 +120,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 // K1 should be Number<...> static constexpr auto K1 = Number{}; + using ThisThreadBlock = ThisThreadBlock; + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { constexpr auto max_lds_align = K1; @@ -420,27 +422,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 }(); // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_b_k0_m_k1_grid_desc), - decltype(a_b_k0_m_k1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 2, 1, 3>, - ABlockTransferSrcVectorDim, - 3, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( a_b_k0_m_k1_grid_desc, make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), a_element_op, @@ -450,27 +452,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_b_k0_n_k1_grid_desc), - decltype(b_b_k0_n_k1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<0, 2, 1, 3>, - BBlockTransferSrcVectorDim, - 3, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( b_b_k0_n_k1_grid_desc, make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), b_element_op, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index d1ea675e59..bcb7cd104c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -6,8 +6,8 @@ #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer_v4r1.hpp" -#include "blockwise_tensor_slice_transfer_v6r1.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v6r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" namespace ck { @@ -123,6 +123,8 @@ 
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 // K1 should be Number<...> static constexpr auto K1 = Number{}; + using ThisThreadBlock = ThisThreadBlock; + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { constexpr auto max_lds_align = K1; @@ -409,27 +411,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 }(); // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_b_k0_m_k1_grid_desc), - decltype(a_b_k0_m_k1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 2, 1, 3>, - ABlockTransferSrcVectorDim, - 3, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( a_b_k0_m_k1_grid_desc, make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), a_element_op, @@ -439,27 +441,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_b_k0_n_k1_grid_desc), - decltype(b_b_k0_n_k1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<0, 2, 1, 3>, - BBlockTransferSrcVectorDim, - 3, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( b_b_k0_n_k1_grid_desc, make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), b_element_op, @@ -660,8 +662,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ck::tensor_operation::element_wise::PassThrough{}}; // LDS to global - auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1< - BlockSize, // index_t BlockSize, + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, CElementwiseOperation, // ElementwiseOperation, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index fc9cd51c4f..eca71d9f77 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -1,13 +1,11 @@ -#ifndef CK_GRIDWISE_GEMM_XDLOPS_V3R1_HPP -#define CK_GRIDWISE_GEMM_XDLOPS_V3R1_HPP - +#pragma once #include "common_header.hpp" #include "multi_index_transform_helper.hpp" #include "tensor_descriptor.hpp" #include 
"tensor_descriptor_helper.hpp" #include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer_v4r1.hpp" -#include "blockwise_tensor_slice_transfer_v6r1.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v6r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" #include "gridwise_gemm_pipeline_v1.hpp" #include "tensor_space_filling_curve.hpp" @@ -113,7 +111,7 @@ template < index_t CShuffleNXdlPerWavePerShuffle, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, index_t CBlockTransferScalarPerVector_NWaveNPerXdl, - index_t NumPrefetch = 1> + index_t NumGemmKPrefetchStage = 1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 { static constexpr auto I0 = Number<0>{}; @@ -131,6 +129,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 static constexpr auto AK1 = Number{}; static constexpr auto BK1 = Number{}; + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { constexpr auto max_lds_align = AK1; @@ -246,21 +248,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) return false; - // check NumPrefetch - if constexpr(NumPrefetch == 1) - { - // 1-stage prefetch always supported - } - else if constexpr(NumPrefetch == 2) - { - // 2-stage prefetch currently only support even number of K0 loop - // TODO: add support for odd number of K0 loop - if(!((K / KPerBlock) % 2 == 0)) - { - return false; - } - } - else + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; } @@ -290,12 +281,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 return grid_size; } - // TODO move this function into GEMM-pipeline class - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const bool has_main_k0_block_loop = ((K0 * AK1) / (NumPrefetch * KPerBlock)) > 1; + const index_t num_loop = K / KPerBlock; - return has_main_k0_block_loop; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } __host__ __device__ static constexpr auto @@ -413,28 +403,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - NumPrefetch>( + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( a_grid_desc_ak0_m_ak1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -444,28 +434,28 @@ struct 
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - NumPrefetch>( + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( b_grid_desc_bk0_n_bk1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -512,43 +502,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); // gridwise GEMM pipeline - const auto gridwise_gemm_pipeline = - GridwiseGemmPipeline_v1, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - NumPrefetch, - HasMainK0BlockLoop>{}; - const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - blockwise_gemm, - c_thread_buf, - num_k_block_main_loop); + GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); // shuffle C and write out { @@ -672,8 +644,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ck::tensor_operation::element_wise::PassThrough{}}; // LDS to global - auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1< - BlockSize, // index_t BlockSize, + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup CElementwiseOperation, // ElementwiseOperation, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, @@ -774,4 +746,3 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp index 51477cdb40..28624e08f9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -6,8 +6,8 @@ #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "blockwise_gemm_xdlops.hpp" -#include 
"blockwise_tensor_slice_transfer_v4r1.hpp" -#include "blockwise_tensor_slice_transfer_v6r2.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v6r2.hpp" #include "threadwise_tensor_slice_transfer.hpp" #include "gridwise_gemm_pipeline_v1.hpp" @@ -24,7 +24,7 @@ template + bool HasMainKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -48,7 +48,7 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid, p_b_grid, p_c_grid, @@ -119,7 +119,7 @@ template < index_t CShuffleNXdlPerWavePerShuffle, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, index_t CBlockTransferScalarPerVector_NWaveNPerXdl, - index_t NumPrefetch = 1> + index_t NumGemmKPrefetchStage = 1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 { static constexpr auto I0 = Number<0>{}; @@ -134,6 +134,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 // K1 should be Number<...> static constexpr auto K1 = Number{}; + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { constexpr auto max_lds_align = K1; @@ -252,21 +256,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; - // check NumPrefetch - if constexpr(NumPrefetch == 1) - { - // 1-stage prefetch always supported - } - else if constexpr(NumPrefetch == 2) - { - // 2-stage prefetch currently only support even number of K0 loop - // TODO: add support for odd number of K0 loop - if(!((K0 / K0PerBlock) % 2 == 0)) - { - return false; - } - } - else + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; } @@ -296,12 +289,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 return grid_size; } - // TODO move this function into GEMM-pipeline class - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1; + const index_t num_loop = K / (K0PerBlock * K1); - return has_main_k0_block_loop; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } template @@ -379,7 +371,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 using DefaultBlock2CTileMap = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -434,28 +426,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_k0_m_k1), - decltype(a_block_desc_k0_m_k1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - NumPrefetch>( + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + 
ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( a_grid_desc_k0_m_k1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -465,28 +457,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_n_k1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - NumPrefetch>( + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( b_grid_desc_k0_n_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -531,41 +523,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); // gridwise GEMM pipeline - const auto gridwise_gemm_pipeline = - GridwiseGemmPipeline_v1, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - NumPrefetch, - HasMainK0BlockLoop>{}; - const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1, - a_block_desc_k0_m_k1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_k0_n_k1, - b_block_desc_k0_n_k1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - blockwise_gemm, - c_thread_buf, - K0BlockMainLoop); + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); // shuffle C and write out { @@ -690,8 +664,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 n_thread_data_on_block_idx[I2]), ck::tensor_operation::element_wise::PassThrough{}}; - auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r2< - BlockSize, // index_t BlockSize, + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r2< + ThisThreadBlock, // index_t BlockSize, CElementwiseOperation, // ElementwiseOperation, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp index fa6f1d1f6b..46d00c7e1e 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -1,13 +1,11 @@ -#ifndef CK_GRIDWISE_GEMM_XDLOPS_V3R3_HPP -#define CK_GRIDWISE_GEMM_XDLOPS_V3R3_HPP - +#pragma once #include "common_header.hpp" #include "multi_index_transform_helper.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer_v4r1.hpp" -#include "blockwise_tensor_slice_transfer_v6r3.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v6r3.hpp" #include "threadwise_tensor_slice_transfer.hpp" #include "gridwise_gemm_pipeline_v1.hpp" @@ -25,7 +23,7 @@ template + bool HasMainKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -52,7 +50,7 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid, p_b_grid, p_c_grid, @@ -128,7 +126,7 @@ template < index_t CShuffleNXdlPerWavePerShuffle, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, index_t CBlockTransferScalarPerVector_NWaveNPerXdl, - index_t NumPrefetch = 1> + index_t NumGemmKPrefetchStage = 1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 { static constexpr auto I0 = Number<0>{}; @@ -143,6 +141,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 // K1 should be Number<...> static constexpr auto K1 = Number{}; + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { constexpr auto max_lds_align = K1; @@ -261,21 +263,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; - // check NumPrefetch - if constexpr(NumPrefetch == 1) - { - // 1-stage prefetch always supported - } - else if constexpr(NumPrefetch == 2) - { - // 2-stage prefetch currently only support even number of K0 loop - // TODO: add support for odd number of K0 loop - if(!((K0 / K0PerBlock) % 2 == 0)) - { - return false; - } - } - else + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; } @@ -305,12 +296,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 return grid_size; } - // TODO move this function into GEMM-pipeline class - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1; + const index_t num_loop = K / (K0PerBlock * K1); - return has_main_k0_block_loop; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } template @@ -393,7 +383,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 using DefaultBlock2CTileMap = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -455,27 +445,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_K0_M_K1, - 
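// Standalone check that the new host-side predicate above matches the old one: with
// K = K0 * K1 (the unmerged K decomposition used throughout these kernels),
// K / (K0PerBlock * K1) is the same trip count the old K0 / K0PerBlock produced, so
// CalculateHasMainKBlockLoop(K) is a pure interface change, not a behavior change.
namespace loop_count_sketch {

using index_t = int;

constexpr index_t K0PerBlock = 8;
constexpr index_t K1         = 8;
constexpr index_t K0         = 64;
constexpr index_t K          = K0 * K1; // 512

static_assert(K / (K0PerBlock * K1) == K0 / K0PerBlock,
              "CalculateHasMainKBlockLoop(K) sees the same trip count as the old "
              "CalculateHasMainK0BlockLoop(K0)");

} // namespace loop_count_sketch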
ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_k0_m_k1), - decltype(a_block_desc_k0_m_k1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( a_grid_desc_k0_m_k1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -485,27 +475,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_n_k1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( b_grid_desc_k0_n_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -550,41 +540,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); // gridwise GEMM pipeline - const auto gridwise_gemm_pipeline = - GridwiseGemmPipeline_v1, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - NumPrefetch, - HasMainK0BlockLoop>{}; - const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1, - a_block_desc_k0_m_k1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_k0_n_k1, - b_block_desc_k0_n_k1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - blockwise_gemm, - c_thread_buf, - K0BlockMainLoop); + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); // shuffle C and write out { @@ -623,17 +595,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - make_tuple( - make_freeze_transform(I0), // freeze mblock - 
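// Hedged note on the K0BlockMainLoop expression kept in both gridwise kernels above:
// wrapping the division in __builtin_amdgcn_readfirstlane pins the trip count to a
// scalar (wave-uniform) register, so the main K loop stays out of per-lane control
// flow. Minimal device-side illustration; compiles only under hipcc for an AMDGCN
// target, hence the guard.
#if defined(__HIP_DEVICE_COMPILE__)
__device__ inline int sketch_uniform_trip_count(int K0, int K0PerBlock)
{
    // the quotient is already the same in every lane, but the compiler cannot
    // prove it; readfirstlane makes the uniformity explicit
    return __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
}
#endif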
make_pass_through_transform( - Number{}), // M0 (MXdlPerWave) per shuffle - make_unmerge_transform( - make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_freeze_transform(I0), // freeze nblock - make_pass_through_transform( - Number{}), // N0 (NXdlPerWave) per shuffle - make_unmerge_transform( - make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per + // shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per + // shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, @@ -709,8 +682,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 n_thread_data_on_block_idx[I2]), ck::tensor_operation::element_wise::PassThrough{}}; - auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r3< - BlockSize, // index_t BlockSize, + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3< + ThisThreadBlock, // ThreadGroup CElementwiseOperation, // ElementwiseOperation, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, @@ -851,4 +824,3 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 6521913541..7a75ca5380 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -51,7 +51,7 @@ template {}]); + element_op_(v, src_buf[Number{}]); // apply type convert - dst_vector.template AsType()(i) = type_convert(dst_v); + dst_vector.template AsType()(i) = type_convert(v); }); const bool is_dst_valid = @@ -213,7 +212,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 private: DstCoord dst_coord_; - const DstElementwiseOperation dst_element_op_; + const ElementwiseOperation element_op_; }; // namespace ThreadwiseTensorSliceTransfer_v1r3 // Assume: diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp index c6360d3b29..042bc95f55 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp @@ -102,8 +102,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1 // apply pointwise operation static_for<0, ScalarPerVector, 1>{}([&](auto i) { - element_op_(dst_vector_container.template AsType()(i), - src_vector_container.template AsType()[i]); + SrcData v; + + // apply element-wise operation + element_op_(v, src_vector_container.template AsType()[i]); + + // apply type convert + dst_vector_container.template AsType()(i) = type_convert(v); }); const bool is_dst_valid = diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 91d109bae1..94693f510e 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -266,8 +266,8 @@ struct intrin_mfma_i32_32x32x8i8<32, 32> __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) { reg_c.template AsType()(Number<0>{}) = - 
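// The two threadwise-transfer hunks above (v1r3 and v6r1) share one fix: the
// elementwise op now writes into a SrcData temporary and a single explicit
// type_convert produces the DstData value, instead of the op writing straight into
// the destination vector. A standalone model of the fixed pattern (assumption:
// type_convert is modeled with static_cast here; the real CK helper also handles
// bf16 and friends):
namespace elementwise_convert_sketch {

template <typename DstData, typename SrcData, typename ElementOp>
DstData apply_then_convert(const SrcData& s, const ElementOp& element_op)
{
    SrcData v;
    element_op(v, s);               // run the op in source precision
    return static_cast<DstData>(v); // then convert exactly once
}

inline int demo()
{
    auto scale2 = [](float& y, const float& x) { y = 2.0f * x; };
    return apply_then_convert<int>(3.25f, scale2); // 6.5f -> 6
}

} // namespace elementwise_convert_sketch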
__builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast(reg_a), - bit_cast(reg_b), + __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, @@ -285,8 +285,8 @@ struct intrin_mfma_i32_16x16x16i8<16, 16> __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) { reg_c.template AsType()(Number<0>{}) = - __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast(reg_a), - bit_cast(reg_b), + __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, diff --git a/include/ck/utility/common_header.hpp b/include/ck/utility/common_header.hpp index c1bc937062..539263703b 100644 --- a/include/ck/utility/common_header.hpp +++ b/include/ck/utility/common_header.hpp @@ -28,6 +28,7 @@ #include "transpose_vectors.hpp" #include "inner_product.hpp" #include "element_wise_operation.hpp" +#include "thread_group.hpp" #include "debug.hpp" #include "amd_buffer_addressing.hpp" diff --git a/include/ck/utility/get_id.hpp b/include/ck/utility/get_id.hpp index f742512d40..d1288a2274 100644 --- a/include/ck/utility/get_id.hpp +++ b/include/ck/utility/get_id.hpp @@ -3,11 +3,15 @@ namespace ck { -__device__ constexpr index_t get_wave_size() { return CK_GPU_WAVE_SIZE; } +__host__ __device__ constexpr index_t get_warp_size() +{ + // warpSize is defined by HIP + return warpSize; +} __device__ index_t get_thread_local_1d_id() { return threadIdx.x; } -__device__ index_t get_wave_local_1d_id() { return threadIdx.x / get_wave_size(); } +__device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); } __device__ index_t get_block_1d_id() { return blockIdx.x; } diff --git a/include/ck/utility/thread_group.hpp b/include/ck/utility/thread_group.hpp new file mode 100644 index 0000000000..bd3563c5f1 --- /dev/null +++ b/include/ck/utility/thread_group.hpp @@ -0,0 +1,18 @@ +#pragma once +#include "get_id.hpp" + +namespace ck { + +template +struct ThisThreadBlock +{ + static constexpr index_t kNumThread_ = ThreadPerBlock; + + __device__ static constexpr index_t GetNumOfThread() { return kNumThread_; } + + __device__ static constexpr bool IsBelong() { return true; } + + __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); } +}; + +} // namespace ck diff --git a/include/ck/utility/tuple.hpp b/include/ck/utility/tuple.hpp index 96cab4b99e..766a78240b 100644 --- a/include/ck/utility/tuple.hpp +++ b/include/ck/utility/tuple.hpp @@ -21,9 +21,9 @@ struct TupleElement { __host__ __device__ constexpr TupleElement() = default; - template >, TupleElement>::value, - bool>::type = false> + template < + typename T, + typename enable_if, TupleElement>::value, bool>::type = false> __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward(v)) { } @@ -60,7 +60,7 @@ struct TupleImpl, Xs...> : TupleElement, Xs> template >, TupleImpl>::value, + !is_same, TupleImpl>::value, bool>::type = false> __host__ __device__ constexpr TupleImpl(Y&& y) : TupleElement, Xs>(std::forward(y))... 
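// ThisThreadBlock (the new thread_group.hpp above) is the first instance of a small
// ThreadGroup concept: any type answering GetNumOfThread / IsBelong / GetThreadId
// can drive the ThreadGroupTensorSliceTransfer_* copies earlier in this patch. A
// standalone model with a hypothetical sub-group instance (not part of this patch),
// to show why the transfers now take a type instead of a raw BlockSize integer:
namespace thread_group_sketch {

using index_t = int;

// stand-in for ThisThreadBlock<256>: every thread in the block participates
struct WholeBlock256
{
    static constexpr index_t GetNumOfThread() { return 256; }
    static constexpr bool IsBelong() { return true; }
    static index_t GetThreadId(); // threadIdx.x on device
};

// hypothetical future group: only the first wave participates
struct FirstWaveOnly
{
    static constexpr index_t GetNumOfThread() { return 64; }
    static bool IsBelong();       // would test GetThreadId() < 64 on device
    static index_t GetThreadId();
};

template <typename ThreadGroup>
constexpr index_t workers() { return ThreadGroup::GetNumOfThread(); }

static_assert(workers<WholeBlock256>() == 256, "whole block participates");
static_assert(workers<FirstWaveOnly>() == 64, "sub-group would use one wave");

} // namespace thread_group_sketch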
@@ -101,8 +101,7 @@ struct Tuple : detail::TupleImpl>, Tuple>::value, + typename enable_if, Tuple>::value, bool>::type = false> __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward(y)) { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 3601fafc28..1b49ca5740 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -1,6 +1,4 @@ -#ifndef REFERENCE_GEMM_HPP -#define REFERENCE_GEMM_HPP - +#pragma once #include #include #include "device_base.hpp" @@ -129,4 +127,3 @@ struct ReferenceGemm : public device::BaseOperator } // namespace host } // namespace tensor_operation } // namespace ck -#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp index 791d0c2810..de97b60a62 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,6 +1,6 @@ #include #include "config.hpp" -#include "device_gemm_xdl_c_shuffle.hpp" +#include "device_gemm_xdl_cshuffle.hpp" #include "element_wise_operation.hpp" #include "device_operation_instance.hpp" @@ -20,26 +20,28 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + // Compilation parameters for a[m, k] * b[n, k] = c[m, n] using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format off - //#####################| AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Num| - //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Prefetch| - //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | - //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, - DeviceGemmXdl_C_Shuffle< 
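// The tuple.hpp hunks above adjust the enable_if guards on the forwarding
// constructors; the role of such guards is to keep copy construction from being
// hijacked by the universal-reference overload. A minimal standalone model of the
// same guard pattern (names here are illustrative, not CK's):
#include <type_traits>
#include <utility>

namespace ctor_guard_sketch {

struct Box
{
    template <typename T,
              typename std::enable_if<!std::is_same<typename std::decay<T>::type, Box>::value,
                                      bool>::type = false>
    Box(T&& v) : tag(1) { (void)std::forward<T>(v); } // forwarding ctor, disabled for Box

    Box(const Box&) = default; // so the real copy constructor wins

    int tag = 0;
};

inline bool demo()
{
    Box a(42); // forwarding constructor
    Box b(a);  // copy constructor, thanks to the guard
    return b.tag == a.tag;
}

} // namespace ctor_guard_sketch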
F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2> + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index f266188844..93262fe802 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -421,7 +421,7 @@ void profile_gemm_impl(int do_verification, std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N; + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; float tflops = static_cast(flop) / 1.E9 / ave_time; diff --git a/test/gemm/gemm_bf16.cpp b/test/gemm/gemm_bf16.cpp index 3f08acb1e6..5461088b02 100644 --- a/test/gemm/gemm_bf16.cpp +++ b/test/gemm/gemm_bf16.cpp @@ -15,7 +15,7 @@ #include "host_gemm.hpp" #include "device_tensor.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_c_shuffle.hpp" +#include "device_gemm_xdl_cshuffle.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" diff --git a/test/gemm/gemm_fp16.cpp b/test/gemm/gemm_fp16.cpp index d7669bb242..aeffeafd3e 100644 --- a/test/gemm/gemm_fp16.cpp +++ b/test/gemm/gemm_fp16.cpp @@ -13,7 +13,7 @@ #include "host_gemm.hpp" #include "device_tensor.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_c_shuffle.hpp" +#include "device_gemm_xdl_cshuffle.hpp" #include "element_wise_operation.hpp" #include "gemm_specialization.hpp" diff --git a/test/gemm/gemm_fp32.cpp b/test/gemm/gemm_fp32.cpp index 6c86085f3b..10b5175c37 100644 --- a/test/gemm/gemm_fp32.cpp +++ b/test/gemm/gemm_fp32.cpp @@ -15,7 +15,7 @@ #include "host_gemm.hpp" #include "device_tensor.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_c_shuffle.hpp" +#include "device_gemm_xdl_cshuffle.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_int8.cpp index 864fca8df4..870881dd76 100644 --- a/test/gemm/gemm_int8.cpp +++ b/test/gemm/gemm_int8.cpp @@ -15,7 +15,7 @@ #include "host_gemm.hpp" #include "device_tensor.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_c_shuffle.hpp" +#include "device_gemm_xdl_cshuffle.hpp" #include "element_wise_operation.hpp" 
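// The profile_gemm_impl.hpp hunk above is a genuine bug fix, not a rename: the
// bytes-moved estimate counted the B matrix as K*M elements instead of K*N, so
// reported effective bandwidth was wrong whenever M != N. Standalone check:
namespace num_btype_sketch {

using size_type = unsigned long long;

constexpr size_type M = 1024, N = 512, K = 256;
constexpr size_type elem = 2; // e.g. sizeof a 16-bit element

constexpr size_type old_btype = elem * M * K + elem * K * M + elem * M * N; // buggy
constexpr size_type new_btype = elem * M * K + elem * K * N + elem * M * N; // fixed

static_assert(old_btype != new_btype, "the two formulas only agree when M == N");

} // namespace num_btype_sketch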
#include "reference_gemm.hpp" #include "gemm_specialization.hpp"