mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Multi-kernel CGEMM (#230)
* Reference CGEMM + test stub * Format. * Incomplete simple implementation * Library instances * Sketch of tests * Test fixes. * Example added * Cosmetics * Add elementwise operation kernel and example * Add comment * Add template argument of dim . Prepare to support multiple dimension * Rename example * Support 1 dimension * Add static assert * Add comment * Second auxiliary buffer added * Extract pad * Remove redundant argument * Support any dimension for elementwise operation * Remove line * Let it be the multiple number of CU * Move thread per block to the parameter of constructor * Consuming binary ops to do A+B / A-B * Fix + cosmetics + bf16 test commented out temporarily * Format * Enabling bf16 test * Revert "Enabling bf16 test" This reverts commitf497e2ba44. * Fix + test reenabled * fix build * Revert "fix build" This reverts commitd73102384b. * post PR #235 merge fix * amend * Single workspace for cgemm + helper * Perf calc fix * Review remarks: static_cast * Review remarks: binary ops templated * Cleaning * Removal of instances and their tests * Review remarks from aosew addressed * Review remark: unnecessary attribute * Post-merge fixes * Restrict 4gemm to PassThrough + bug fix * Review remarks * update licence * change cgemm example to fp16 Co-authored-by: rocking <chunylai@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com> Co-authored-by: Anthony Chang <ac.chang@outlook.com>
This commit is contained in:
@@ -1,3 +1,28 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2022 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
#include "check_err.hpp"
|
||||
@@ -17,7 +42,8 @@ using ABDataType = F16;
|
||||
using CDataType = F16;
|
||||
using EltwiseComputeDataType = F32;
|
||||
|
||||
using Add = ck::tensor_operation::binary_element_wise::Add;
|
||||
using Add = ck::tensor_operation::binary_element_wise::
|
||||
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
|
||||
|
||||
using DeviceElementwiseAddInstance =
|
||||
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
|
||||
@@ -46,19 +72,19 @@ void host_broadcast2D(
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
ComputeDataType Amn = static_cast<ComputeDataType>(A(m, n));
|
||||
ComputeDataType Amn = ck::type_convert<ComputeDataType>(A(m, n));
|
||||
ComputeDataType Cmn = 0;
|
||||
if constexpr(broadcastDim == 0)
|
||||
{
|
||||
ComputeDataType Bn = static_cast<ComputeDataType>(B(n));
|
||||
ComputeDataType Bn = ck::type_convert<ComputeDataType>(B(n));
|
||||
functor(Cmn, Amn, Bn);
|
||||
}
|
||||
else
|
||||
{
|
||||
ComputeDataType Bm = static_cast<ComputeDataType>(B(m));
|
||||
ComputeDataType Bm = ck::type_convert<ComputeDataType>(B(m));
|
||||
functor(Cmn, Amn, Bm);
|
||||
}
|
||||
C(m, n) = static_cast<ctype>(Cmn);
|
||||
C(m, n) = ck::type_convert<ctype>(Cmn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,8 @@ using ABDataType = F16;
|
||||
using CDataType = F16;
|
||||
using EltwiseComputeDataType = F32;
|
||||
|
||||
using Add = ck::tensor_operation::binary_element_wise::Add;
|
||||
using Add = ck::tensor_operation::binary_element_wise::
|
||||
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
|
||||
|
||||
using DeviceElementwiseAddInstance =
|
||||
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
|
||||
@@ -48,11 +49,11 @@ void host_broadcast3D_am_bmnk(HostTensorC& C,
|
||||
for(std::size_t n = 0; n < shape[1]; ++n)
|
||||
for(std::size_t k = 0; k < shape[2]; ++k)
|
||||
{
|
||||
ComputeDataType a_val = static_cast<ComputeDataType>(A(m));
|
||||
ComputeDataType b_val = static_cast<ComputeDataType>(B(m, n, k));
|
||||
ComputeDataType a_val = ck::type_convert<ComputeDataType>(A(m));
|
||||
ComputeDataType b_val = ck::type_convert<ComputeDataType>(B(m, n, k));
|
||||
ComputeDataType c_val = 0;
|
||||
functor(c_val, a_val, b_val);
|
||||
C(m, n, k) = static_cast<ctype>(c_val);
|
||||
C(m, n, k) = ck::type_convert<ctype>(c_val);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,28 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2022 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
#include "check_err.hpp"
|
||||
@@ -17,7 +42,8 @@ using ABDataType = F16;
|
||||
using CDataType = F16;
|
||||
using EltwiseComputeDataType = F32;
|
||||
|
||||
using Add = ck::tensor_operation::binary_element_wise::Add;
|
||||
using Add = ck::tensor_operation::binary_element_wise::
|
||||
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
|
||||
|
||||
using DeviceElementwiseAddInstance =
|
||||
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
|
||||
@@ -43,11 +69,11 @@ void host_elementwise1D(
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
ComputeDataType Am = static_cast<ComputeDataType>(A(m));
|
||||
ComputeDataType Bm = static_cast<ComputeDataType>(B(m));
|
||||
ComputeDataType Am = ck::type_convert<ComputeDataType>(A(m));
|
||||
ComputeDataType Bm = ck::type_convert<ComputeDataType>(B(m));
|
||||
ComputeDataType Cm = 0;
|
||||
functor(Cm, Am, Bm);
|
||||
C(m) = static_cast<ctype>(Cm);
|
||||
C(m) = ck::type_convert<ctype>(Cm);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,28 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2020 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
#include "check_err.hpp"
|
||||
@@ -17,7 +42,8 @@ using ABDataType = F16;
|
||||
using CDataType = F16;
|
||||
using EltwiseComputeDataType = F32;
|
||||
|
||||
using Add = ck::tensor_operation::binary_element_wise::Add;
|
||||
using Add = ck::tensor_operation::binary_element_wise::
|
||||
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
|
||||
|
||||
using DeviceElementwiseAddInstance =
|
||||
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
|
||||
@@ -49,11 +75,11 @@ void host_elementwise4D(HostTensorC& C,
|
||||
for(std::size_t h = 0; h < shape[2]; ++h)
|
||||
for(std::size_t w = 0; w < shape[3]; ++w)
|
||||
{
|
||||
ComputeDataType a_val = static_cast<ComputeDataType>(A(n, c, h, w));
|
||||
ComputeDataType b_val = static_cast<ComputeDataType>(B(n, c, h, w));
|
||||
ComputeDataType a_val = ck::type_convert<ComputeDataType>(A(n, c, h, w));
|
||||
ComputeDataType b_val = ck::type_convert<ComputeDataType>(B(n, c, h, w));
|
||||
ComputeDataType c_val = 0;
|
||||
functor(c_val, a_val, b_val);
|
||||
C(n, c, h, w) = static_cast<ctype>(c_val);
|
||||
C(n, c, h, w) = ck::type_convert<ctype>(c_val);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
1
example/22_cgemm/CMakeLists.txt
Normal file
1
example/22_cgemm/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp)
|
||||
302
example/22_cgemm/cgemm_xdl_fp16.cpp
Normal file
302
example/22_cgemm/cgemm_xdl_fp16.cpp
Normal file
@@ -0,0 +1,302 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2022 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_cgemm_4gemm_xdl_cshuffle.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reference_cgemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using CDataType = F16;
|
||||
using AccDataType = F32;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
<ALayout, // typename ALayout
|
||||
BLayout, // typename BLayout
|
||||
CLayout, // typename CLayout
|
||||
ADataType, // typename ADataType
|
||||
BDataType, // typename BDataType
|
||||
CDataType, // typename CDataType
|
||||
AccDataType, // typename GemmAccDataType
|
||||
CDataType, // typename CShuffleDataType
|
||||
PassThrough, // typename AElementwiseOperation
|
||||
PassThrough, // typename BElementwiseOperation
|
||||
PassThrough, // typename CElementwiseOperation
|
||||
GemmDefault, // GemmSpecialization GemmSpec
|
||||
1, // index_t NumGemmKPrefetchStage
|
||||
256, // index_t BlockSize
|
||||
256, // index_t MPerBlock
|
||||
128, // index_t NPerBlock
|
||||
32, // index_t KPerBlock
|
||||
8, // index_t AK1
|
||||
8, // index_t BK1
|
||||
32, // index_t MPerXDL
|
||||
32, // index_t NPerXDL
|
||||
4, // index_t MXdlPerWave
|
||||
2, // index_t NXdlPerWave
|
||||
S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder
|
||||
2, // index_t ABlockTransferSrcVectorDim
|
||||
8, // index_t ABlockTransferSrcScalarPerVector
|
||||
8, // index_t ABlockTransferDstScalarPerVector_AK1
|
||||
1, // index_t ABlockLdsExtraM
|
||||
S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder
|
||||
2, // index_t BBlockTransferSrcVectorDim
|
||||
8, // index_t BBlockTransferSrcScalarPerVector
|
||||
8, // index_t BBlockTransferDstScalarPerVector_BK1
|
||||
1, // index_t BBlockLdsExtraN
|
||||
1, // index_t CShuffleMXdlPerWavePerShuffle
|
||||
1, // index_t CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
// clang-format on
|
||||
|
||||
using ReferenceCGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceCGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// CGEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k_real(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<ADataType> a_m_k_imag(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n_real(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<BDataType> b_k_n_imag(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_real_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_imag_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k_real: " << a_m_k_real.mDesc << std::endl;
|
||||
std::cout << "a_m_k_imag: " << a_m_k_imag.mDesc << std::endl;
|
||||
std::cout << "b_k_n_real: " << b_k_n_real.mDesc << std::endl;
|
||||
std::cout << "b_k_n_imag: " << b_k_n_imag.mDesc << std::endl;
|
||||
std::cout << "c_m_n_real: " << c_m_n_real_device_result.mDesc << std::endl;
|
||||
std::cout << "c_m_n_imag: " << c_m_n_imag_device_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k_real.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
a_m_k_imag.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n_real.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b_k_n_imag.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
a_m_k_real.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
|
||||
a_m_k_imag.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
|
||||
b_k_n_real.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
b_k_n_imag.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
auto cgemm = DeviceCGemmInstance{};
|
||||
|
||||
DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * a_m_k_real.mDesc.GetElementSpace());
|
||||
DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * a_m_k_imag.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * b_k_n_real.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * b_k_n_imag.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_real_device_buf(sizeof(CDataType) *
|
||||
c_m_n_real_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) *
|
||||
c_m_n_imag_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem workspace_device_buf(cgemm.GetWorkspaceSize(M, N, K, StrideA, StrideB, StrideC));
|
||||
|
||||
a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data());
|
||||
a_m_k_imag_device_buf.ToDevice(a_m_k_imag.mData.data());
|
||||
b_k_n_real_device_buf.ToDevice(b_k_n_real.mData.data());
|
||||
b_k_n_imag_device_buf.ToDevice(b_k_n_imag.mData.data());
|
||||
|
||||
auto a_element_op = PassThrough{};
|
||||
auto b_element_op = PassThrough{};
|
||||
auto c_element_op = PassThrough{};
|
||||
|
||||
// do GEMM
|
||||
auto invoker = cgemm.MakeInvoker();
|
||||
auto argument =
|
||||
cgemm.MakeArgument(static_cast<ADataType*>(a_m_k_real_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ADataType*>(a_m_k_imag_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_real_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_imag_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_imag_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(workspace_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!cgemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_cgemm with the specified compilation parameters does "
|
||||
"not support this CGEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(8) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
std::size_t(2) *
|
||||
(sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< cgemm.GetTypeString() << std::endl;
|
||||
|
||||
c_m_n_real_device_buf.FromDevice(c_m_n_real_device_result.mData.data());
|
||||
c_m_n_imag_device_buf.FromDevice(c_m_n_imag_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<CDataType> c_m_n_real_host_result(
|
||||
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_imag_host_result(
|
||||
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
auto ref_cgemm = ReferenceCGemmInstance{};
|
||||
auto ref_invoker = ref_cgemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_cgemm.MakeArgument(a_m_k_real,
|
||||
a_m_k_imag,
|
||||
b_k_n_real,
|
||||
b_k_n_imag,
|
||||
c_m_n_real_host_result,
|
||||
c_m_n_imag_host_result,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ck::utils::check_err(c_m_n_real_device_result.mData,
|
||||
c_m_n_real_host_result.mData,
|
||||
"Verification error: incorrect results in real part!",
|
||||
1e-2f,
|
||||
1e-1f);
|
||||
ck::utils::check_err(c_m_n_imag_device_result.mData,
|
||||
c_m_n_imag_host_result.mData,
|
||||
"Verification error: incorrect results in imaginary part!",
|
||||
1e-2f,
|
||||
1e-1f);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -48,10 +48,11 @@ add_subdirectory(11_conv2d_bwd_weight)
|
||||
add_subdirectory(12_reduce)
|
||||
add_subdirectory(13_pool2d_fwd)
|
||||
add_subdirectory(14_gemm_xdl_requant_relu_requant)
|
||||
add_subdirectory(17_convnd_bwd_data_xdl)
|
||||
add_subdirectory(15_grouped_gemm)
|
||||
add_subdirectory(16_gemm_reduce)
|
||||
add_subdirectory(17_convnd_bwd_data_xdl)
|
||||
add_subdirectory(18_batched_gemm_reduce)
|
||||
add_subdirectory(19_binary_elementwise)
|
||||
add_subdirectory(20_convnd_bwd_weight_xdl)
|
||||
add_subdirectory(21_gemm_layernorm)
|
||||
add_subdirectory(22_cgemm)
|
||||
|
||||
73
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
Normal file
73
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
Normal file
@@ -0,0 +1,73 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2022 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#pragma once
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceCGemm : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a_real,
|
||||
const void* p_a_imag,
|
||||
const void* p_b_real,
|
||||
const void* p_b_imag,
|
||||
void* p_c_real,
|
||||
void* p_c_imag,
|
||||
void* p_workspace,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
ck::index_t KBatch = 1) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
virtual std::size_t GetWorkspaceSize(index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC) = 0;
|
||||
};
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
using DeviceCGemmPtr = std::unique_ptr<
|
||||
DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,974 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2022 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_gemm.hpp"
|
||||
#include "device_cgemm.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdl_cshuffle_v1.hpp"
|
||||
#include "binary_element_wise_operation.hpp"
|
||||
#include "gridwise_binary_elementwise_1d.hpp"
|
||||
#include "tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
enable_if_t<
|
||||
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
|
||||
bool> = false>
|
||||
struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
: public DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceCGemm_4Gemm_Xdl_CShuffle;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr auto MPerThread = Number<4>{};
|
||||
static constexpr auto AScalarPerVector = Number<4>{};
|
||||
static constexpr auto BScalarPerVector = Number<4>{};
|
||||
static constexpr auto CScalarPerVector = Number<4>{};
|
||||
|
||||
template <typename Desc_M>
|
||||
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
|
||||
{
|
||||
const auto M = desc_m.GetLength(I0);
|
||||
const index_t loop_step = gridSize * blockSize * MPerThread;
|
||||
const auto pad = math::integer_least_multiple(M, loop_step) - M;
|
||||
const auto desc_m_pad =
|
||||
transform_tensor_descriptor(desc_m,
|
||||
make_tuple(make_right_pad_transform(M, pad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return desc_m_pad;
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
|
||||
const std::vector<index_t>& strides,
|
||||
index_t gridSize,
|
||||
index_t blockSize)
|
||||
{
|
||||
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<2>{});
|
||||
auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<2>{});
|
||||
|
||||
// nd desc - [s0, s1, s2, ...]
|
||||
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||
const auto desc_m = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(tupleOfShape)),
|
||||
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<2>{})),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
|
||||
}
|
||||
|
||||
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(StrideA, I1));
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(I1, StrideA));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad both M and K
|
||||
assert(K % AK1 == 0);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
const auto a_grid_desc_m_k =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
// pad M, but not K
|
||||
assert(KRaw % AK1 == 0);
|
||||
|
||||
const auto AK0 = KRaw / AK1;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_right_pad_transform(MRaw, MPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad K, but not M
|
||||
assert(K % AK1 == 0);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
const auto a_grid_desc_m_k = transform_tensor_descriptor(
|
||||
a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(MRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or K
|
||||
assert(KRaw % AK1 == 0);
|
||||
|
||||
const auto AK0 = KRaw / AK1;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(MRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
|
||||
|
||||
const auto NPad = N - NRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad both N and K
|
||||
assert(K % BK1 == 0);
|
||||
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
const auto b_grid_desc_n_k =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_right_pad_transform(NRaw, NPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
// pad N, but not K
|
||||
assert(KRaw % BK1 == 0);
|
||||
|
||||
const auto BK0 = KRaw / BK1;
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad K, but not N
|
||||
assert(K % BK1 == 0);
|
||||
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
const auto b_grid_desc_n_k = transform_tensor_descriptor(
|
||||
b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad N or K
|
||||
assert(KRaw % BK1 == 0);
|
||||
|
||||
const auto BK0 = KRaw / BK1;
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
|
||||
{
|
||||
const auto c_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(StrideC, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(I1, StrideC));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad M and N
|
||||
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad M, but not N
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad N, but not M
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or N
|
||||
return c_grid_desc_mraw_nraw;
|
||||
}
|
||||
}
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
|
||||
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const ADataType* p_a_grid_real,
|
||||
const ADataType* p_a_grid_imag,
|
||||
const BDataType* p_b_grid_real,
|
||||
const BDataType* p_b_grid_imag,
|
||||
CDataType* p_c_grid_real,
|
||||
CDataType* p_c_grid_imag,
|
||||
CDataType* p_workspace,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: p_a_grid_real_{p_a_grid_real},
|
||||
p_a_grid_imag_{p_a_grid_imag},
|
||||
p_b_grid_real_{p_b_grid_real},
|
||||
p_b_grid_imag_{p_b_grid_imag},
|
||||
p_c_grid_real_{p_c_grid_real},
|
||||
p_c_grid_imag_{p_c_grid_imag},
|
||||
p_aux_grid_{p_workspace},
|
||||
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
|
||||
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
|
||||
b_grid_desc_bk0_n_bk1_,
|
||||
c_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n_);
|
||||
}
|
||||
|
||||
const index_t grid_size = block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_);
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
c_grid_desc_m_ =
|
||||
DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {StrideC, I1}, grid_size, BlockSize);
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
{
|
||||
c_grid_desc_m_ =
|
||||
DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize);
|
||||
}
|
||||
|
||||
p_aux_2_grid_ = p_workspace + c_grid_desc_m_n_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
// private:
|
||||
const ADataType* p_a_grid_real_;
|
||||
const ADataType* p_a_grid_imag_;
|
||||
const BDataType* p_b_grid_real_;
|
||||
const BDataType* p_b_grid_imag_;
|
||||
CDataType* p_c_grid_real_;
|
||||
CDataType* p_c_grid_imag_;
|
||||
CDataType* p_aux_grid_;
|
||||
CDataType* p_aux_2_grid_;
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
CGridDesc_M c_grid_desc_m_;
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
using Add =
|
||||
ck::tensor_operation::binary_element_wise::Add<CDataType, CDataType, CDataType>;
|
||||
using Substract = ck::tensor_operation::binary_element_wise::
|
||||
Substract<CDataType, CDataType, CDataType>;
|
||||
using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
Add,
|
||||
MPerThread,
|
||||
AScalarPerVector,
|
||||
BScalarPerVector,
|
||||
CScalarPerVector>;
|
||||
using GridwiseBinSubstract = GridwiseBinaryElementwise_1D<CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
Substract,
|
||||
MPerThread,
|
||||
AScalarPerVector,
|
||||
BScalarPerVector,
|
||||
CScalarPerVector>;
|
||||
const auto add_kernel = kernel_binary_elementwise_1d<GridwiseBinAdd,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
Add>;
|
||||
const auto substract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubstract,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
Substract>;
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
true>;
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_real = aux - aux_2
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
substract_kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_aux_grid_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.p_c_grid_real_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
Substract{});
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_imag = aux + aux_2
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
add_kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_aux_grid_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.p_c_grid_imag_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
Add{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
false>;
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_real = aux - aux_2
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
substract_kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_aux_grid_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.p_c_grid_real_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
Substract{});
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_imag = aux + aux_2
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
add_kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_aux_grid_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.p_c_grid_imag_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
Add{});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a_real,
|
||||
const ADataType* p_a_imag,
|
||||
const BDataType* p_b_real,
|
||||
const BDataType* p_b_imag,
|
||||
CDataType* p_c_real,
|
||||
CDataType* p_c_imag,
|
||||
CDataType* p_workspace,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{p_a_real,
|
||||
p_a_imag,
|
||||
p_b_real,
|
||||
p_b_imag,
|
||||
p_c_real,
|
||||
p_c_imag,
|
||||
p_workspace,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a_real,
|
||||
const void* p_a_imag,
|
||||
const void* p_b_real,
|
||||
const void* p_b_imag,
|
||||
void* p_c_real,
|
||||
void* p_c_imag,
|
||||
void* p_workspace,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t /* KBatch */ = 1) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a_real),
|
||||
static_cast<const ADataType*>(p_a_imag),
|
||||
static_cast<const BDataType*>(p_b_real),
|
||||
static_cast<const BDataType*>(p_b_imag),
|
||||
static_cast<CDataType*>(p_c_real),
|
||||
static_cast<CDataType*>(p_c_imag),
|
||||
static_cast<CDataType*>(p_workspace),
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceCGemm_4Gemm_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSize(index_t MRaw,
|
||||
index_t NRaw,
|
||||
[[maybe_unused]] index_t KRaw,
|
||||
[[maybe_unused]] index_t StrideA,
|
||||
[[maybe_unused]] index_t StrideB,
|
||||
index_t StrideC) override
|
||||
{
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC);
|
||||
|
||||
return 2 * sizeof(CDataType) * c_grid_desc_m_n.GetElementSpaceSize();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -60,8 +60,8 @@ template <
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
enable_if_t<
|
||||
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
|
||||
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
|
||||
bool> = false>
|
||||
struct DeviceGemmDl
|
||||
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
|
||||
@@ -1,3 +1,28 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2022 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#pragma once
|
||||
#include "data_type.hpp"
|
||||
|
||||
@@ -5,14 +30,22 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace binary_element_wise {
|
||||
|
||||
struct Add
|
||||
template <typename Y, typename X1, typename X2>
|
||||
struct Add;
|
||||
|
||||
template <>
|
||||
struct Add<double, double, double>
|
||||
{
|
||||
__host__ __device__ constexpr void
|
||||
operator()(double& dst, const double& src1, const double& src2) const
|
||||
{
|
||||
dst = src1 + src2;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Add<float, float, float>
|
||||
{
|
||||
__host__ __device__ constexpr void
|
||||
operator()(float& dst, const float& src1, const float& src2) const
|
||||
{
|
||||
@@ -20,6 +53,75 @@ struct Add
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Add<half_t, half_t, half_t>
|
||||
{
|
||||
__host__ __device__ constexpr void
|
||||
operator()(half_t& dst, const half_t& src1, const half_t& src2) const
|
||||
{
|
||||
dst = src1 + src2;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Add<bhalf_t, bhalf_t, bhalf_t>
|
||||
{
|
||||
__host__ __device__ constexpr void
|
||||
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
|
||||
{
|
||||
const float x1 = ck::type_convert<float>(src1);
|
||||
const float x2 = ck::type_convert<float>(src2);
|
||||
const float y = x1 + x2;
|
||||
dst = ck::type_convert<bhalf_t>(y);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Y, typename X1, typename X2>
|
||||
struct Substract;
|
||||
|
||||
template <>
|
||||
struct Substract<double, double, double>
|
||||
{
|
||||
__host__ __device__ constexpr void
|
||||
operator()(double& dst, const double& src1, const double& src2) const
|
||||
{
|
||||
dst = src1 - src2;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Substract<float, float, float>
|
||||
{
|
||||
__host__ __device__ constexpr void
|
||||
operator()(float& dst, const float& src1, const float& src2) const
|
||||
{
|
||||
dst = src1 - src2;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Substract<half_t, half_t, half_t>
|
||||
{
|
||||
__host__ __device__ constexpr void
|
||||
operator()(half_t& dst, const half_t& src1, const half_t& src2) const
|
||||
{
|
||||
dst = src1 - src2;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Substract<bhalf_t, bhalf_t, bhalf_t>
|
||||
{
|
||||
__host__ __device__ constexpr void
|
||||
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
|
||||
{
|
||||
const float x1 = ck::type_convert<float>(src1);
|
||||
const float x2 = ck::type_convert<float>(src2);
|
||||
const float y = x1 - x2;
|
||||
dst = ck::type_convert<bhalf_t>(y);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace binary_element_wise
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -0,0 +1,203 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2022 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device_base.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
|
||||
// FIXME: support arbitrary elementwise operation for A/B/C
|
||||
template <
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
enable_if_t<
|
||||
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
|
||||
bool> = false>
|
||||
struct ReferenceCGemm : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_m_k_real,
|
||||
const Tensor<ADataType>& a_m_k_imag,
|
||||
const Tensor<BDataType>& b_k_n_real,
|
||||
const Tensor<BDataType>& b_k_n_imag,
|
||||
Tensor<CDataType>& c_m_n_real,
|
||||
Tensor<CDataType>& c_m_n_imag,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: a_m_k_real_{a_m_k_real},
|
||||
a_m_k_imag_{a_m_k_imag},
|
||||
b_k_n_real_{b_k_n_real},
|
||||
b_k_n_imag_{b_k_n_imag},
|
||||
c_m_n_real_{c_m_n_real},
|
||||
c_m_n_imag_{c_m_n_imag},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_m_k_real_;
|
||||
const Tensor<ADataType>& a_m_k_imag_;
|
||||
const Tensor<BDataType>& b_k_n_real_;
|
||||
const Tensor<BDataType>& b_k_n_imag_;
|
||||
Tensor<CDataType>& c_m_n_real_;
|
||||
Tensor<CDataType>& c_m_n_imag_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceCGemm::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
const std::size_t K = arg.a_m_k_real_.mDesc.GetLengths()[1];
|
||||
|
||||
if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1])
|
||||
{
|
||||
throw std::runtime_error("wrong! Incompatible real and imag sizes in CGEMM");
|
||||
}
|
||||
|
||||
auto f_mk_kn_mn_real = [&](auto m, auto n) {
|
||||
float v_c_real = 0;
|
||||
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
float v_a_real = ck::type_convert<float>(arg.a_m_k_real_(m, k));
|
||||
float v_a_imag = ck::type_convert<float>(arg.a_m_k_imag_(m, k));
|
||||
float v_b_real = ck::type_convert<float>(arg.b_k_n_real_(k, n));
|
||||
float v_b_imag = ck::type_convert<float>(arg.b_k_n_imag_(k, n));
|
||||
|
||||
v_c_real += v_a_real * v_b_real - v_a_imag * v_b_imag;
|
||||
}
|
||||
|
||||
arg.c_m_n_real_(m, n) = v_c_real;
|
||||
};
|
||||
|
||||
auto f_mk_kn_mn_imag = [&](auto m, auto n) {
|
||||
float v_c_imag = 0;
|
||||
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
float v_a_real = ck::type_convert<float>(arg.a_m_k_real_(m, k));
|
||||
float v_a_imag = ck::type_convert<float>(arg.a_m_k_imag_(m, k));
|
||||
float v_b_real = ck::type_convert<float>(arg.b_k_n_real_(k, n));
|
||||
float v_b_imag = ck::type_convert<float>(arg.b_k_n_imag_(k, n));
|
||||
|
||||
v_c_imag += v_a_real * v_b_imag + v_a_imag * v_b_real;
|
||||
}
|
||||
|
||||
arg.c_m_n_imag_(m, n) = v_c_imag;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn_real,
|
||||
arg.c_m_n_real_.mDesc.GetLengths()[0],
|
||||
arg.c_m_n_real_.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn_imag,
|
||||
arg.c_m_n_imag_.mDesc.GetLengths()[0],
|
||||
arg.c_m_n_imag_.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_m_k_real,
|
||||
const Tensor<ADataType>& a_m_k_imag,
|
||||
const Tensor<BDataType>& b_k_n_real,
|
||||
const Tensor<BDataType>& b_k_n_imag,
|
||||
Tensor<CDataType>& c_m_n_real,
|
||||
Tensor<CDataType>& c_m_n_imag,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{a_m_k_real,
|
||||
a_m_k_imag,
|
||||
b_k_n_real,
|
||||
b_k_n_imag,
|
||||
c_m_n_real,
|
||||
c_m_n_imag,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceCGemm"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user