mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Conv bwd data multiple d (#404)
* init commit of convnd bwd data * begin compiling example * have a first version that produce a right result * refine device level launch kernel code * add more instances in example and get right results * clang-format * format example file * add more instances * fix instances * adding conv_bwd_data multile_d * adding conv_bwd_data multile_d * adding conv_bwd multiple d * adding conv_bwd multiple d * adding conv_bwd multiple d * refactor * refactor * adding conv bwd data multiple d * adding conv bwd data multiple d * adding conv bwd data multiple d * adding conv bwd data multiple d * adding conv bwd data multiple d * adding conv bwd data multiple d * adding conv bwd data multiple d * refactor * update conv fwd's bias impl * refactor * reorg file * clean up cmake * clean * clean * clean Co-authored-by: Chao Liu <lc.roy86@gmail.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -6,9 +6,10 @@ find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
|
||||
find_package(hip REQUIRED PATHS /opt/rocm)
|
||||
message(STATUS "Build with HIP ${hip_VERSION}")
|
||||
|
||||
add_subdirectory(01_gemm)
|
||||
add_subdirectory(02_gemm_add_add_fastgelu)
|
||||
add_subdirectory(03_gemm_layernorm)
|
||||
add_subdirectory(04_contraction)
|
||||
add_subdirectory(05_layernorm)
|
||||
add_subdirectory(06_softmax)
|
||||
# add all example subdir
|
||||
file(GLOB dir_list LIST_DIRECTORIES true *)
|
||||
FOREACH(subdir ${dir_list})
|
||||
IF(IS_DIRECTORY "${subdir}")
|
||||
add_subdirectory(${subdir})
|
||||
ENDIF()
|
||||
ENDFOREACH()
|
||||
|
||||
@@ -1,258 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute_xdl.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmEPermuteXdl
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
BDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
const int M = 256;
|
||||
const int N = 128;
|
||||
const int K = 64;
|
||||
|
||||
const int stride_A = K;
|
||||
const int stride_B = K;
|
||||
|
||||
const int batch_stride_A = M * K;
|
||||
const int batch_stride_B = K * N;
|
||||
|
||||
const int G0 = 16;
|
||||
const int G1 = 8;
|
||||
|
||||
const int batch_count = G0 * G1;
|
||||
|
||||
// output layout - [G0, M, G1, N]
|
||||
const int stride_G0 = M * G1 * N;
|
||||
const int stride_G1 = N;
|
||||
const int stride_M = G1 * N;
|
||||
const int stride_N = 1;
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// GEMM shape
|
||||
ck::tensor_operation::device::BatchedGemmEPermuteDesc batched_gemm_e_permute_desc{
|
||||
G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N};
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
|
||||
std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
std::size_t batch_stride,
|
||||
auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
|
||||
std::vector<std::size_t>({batch_stride, stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
|
||||
std::vector<std::size_t>({batch_stride, 1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_g_m_k(
|
||||
f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{}));
|
||||
Tensor<BDataType> b_g_k_n(
|
||||
f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{}));
|
||||
|
||||
auto f_host_e_tensor_descriptor = [](std::size_t G0_,
|
||||
std::size_t G1_,
|
||||
std::size_t M_,
|
||||
std::size_t N_,
|
||||
std::size_t stride_G0_,
|
||||
std::size_t stride_G1_,
|
||||
std::size_t stride_M_,
|
||||
std::size_t stride_N_) {
|
||||
return HostTensorDescriptor(
|
||||
std::vector<std::size_t>({G0_, G1_, M_, N_}),
|
||||
std::vector<std::size_t>({stride_G0_, stride_G1_, stride_M_, stride_N_}));
|
||||
};
|
||||
|
||||
Tensor<EDataType> e_g0_g1_m_n_host_result(
|
||||
f_host_e_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
|
||||
|
||||
Tensor<EDataType> e_g0_g1_m_n_device_result(
|
||||
f_host_e_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
|
||||
|
||||
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
|
||||
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
|
||||
std::cout << "e_g0_g1_m_n: " << e_g0_g1_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) *
|
||||
e_g0_g1_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_g_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_g_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
// do GEM
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<EDataType*>(e_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batched_gemm_e_permute_desc,
|
||||
batch_count,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
|
||||
std::size_t num_btype = sizeof(ADataType) * batch_count * M * K +
|
||||
sizeof(BDataType) * batch_count * K * N +
|
||||
sizeof(EDataType) * batch_count * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< gemm.GetTypeString() << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
e_device_buf.FromDevice(e_g0_g1_m_n_device_result.mData.data());
|
||||
|
||||
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
|
||||
auto ref_invoker = ref_batched_gemm.MakeInvoker();
|
||||
|
||||
Tensor<EDataType> c_g_m_n_host_result = HostTensorDescriptor(
|
||||
std::vector<std::size_t>({batch_count, M, N}), std::vector<std::size_t>({M * N, N, 1}));
|
||||
|
||||
auto ref_argument = ref_batched_gemm.MakeArgument(
|
||||
a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int g0 = 0; g0 < G0; g0++)
|
||||
{
|
||||
for(int g1 = 0; g1 < G1; g1++)
|
||||
{
|
||||
for(int m = 0; m < M; m++)
|
||||
{
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
int g = g0 * G1 + g1;
|
||||
|
||||
e_g0_g1_m_n_host_result(g0, g1, m, n) = c_g_m_n_host_result(g, m, n);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pass = ck::utils::check_err(e_g0_g1_m_n_host_result.mData,
|
||||
e_g0_g1_m_n_device_result.mData,
|
||||
"Error: Incorrect results c");
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
@@ -137,7 +137,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using InLayout = ctc::G_NW_C;
|
||||
using WeiLayout = ctc::G_K_X_C;
|
||||
using BiasLayout = ctc::G_NW_K;
|
||||
using BiasLayout = ctc::G_K;
|
||||
using ResidualLayout = ctc::G_NW_K;
|
||||
using OutLayout = ctc::G_NW_K;
|
||||
|
||||
@@ -220,7 +220,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using InLayout = ctc::G_NHW_C;
|
||||
using WeiLayout = ctc::G_K_YX_C;
|
||||
using BiasLayout = ctc::G_NHW_K;
|
||||
using BiasLayout = ctc::G_K;
|
||||
using ResidualLayout = ctc::G_NHW_K;
|
||||
using OutLayout = ctc::G_NHW_K;
|
||||
|
||||
@@ -332,7 +332,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using InLayout = ctc::G_NDHW_C;
|
||||
using WeiLayout = ctc::G_K_ZYX_C;
|
||||
using BiasLayout = ctc::G_NDHW_K;
|
||||
using BiasLayout = ctc::G_K;
|
||||
using ResidualLayout = ctc::G_NDHW_K;
|
||||
using OutLayout = ctc::G_NDHW_K;
|
||||
|
||||
|
||||
@@ -137,7 +137,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using InLayout = ctc::G_NW_C;
|
||||
using WeiLayout = ctc::G_K_X_C;
|
||||
using BiasLayout = ctc::G_NW_K;
|
||||
using BiasLayout = ctc::G_K;
|
||||
using ResidualLayout = ctc::G_NW_K;
|
||||
using OutLayout = ctc::G_NW_K;
|
||||
|
||||
@@ -220,7 +220,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using InLayout = ctc::G_NHW_C;
|
||||
using WeiLayout = ctc::G_K_YX_C;
|
||||
using BiasLayout = ctc::G_NHW_K;
|
||||
using BiasLayout = ctc::G_K;
|
||||
using ResidualLayout = ctc::G_NHW_K;
|
||||
using OutLayout = ctc::G_NHW_K;
|
||||
|
||||
@@ -332,7 +332,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using InLayout = ctc::G_NDHW_C;
|
||||
using WeiLayout = ctc::G_K_ZYX_C;
|
||||
using BiasLayout = ctc::G_NDHW_K;
|
||||
using BiasLayout = ctc::G_K;
|
||||
using ResidualLayout = ctc::G_NDHW_K;
|
||||
using OutLayout = ctc::G_NDHW_K;
|
||||
|
||||
|
||||
@@ -137,7 +137,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using InLayout = ctc::G_NW_C;
|
||||
using WeiLayout = ctc::G_K_X_C;
|
||||
using BiasLayout = ctc::G_NW_K;
|
||||
using BiasLayout = ctc::G_K;
|
||||
using ResidualLayout = ctc::G_NW_K;
|
||||
using OutLayout = ctc::G_NW_K;
|
||||
|
||||
@@ -220,7 +220,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using InLayout = ctc::G_NHW_C;
|
||||
using WeiLayout = ctc::G_K_YX_C;
|
||||
using BiasLayout = ctc::G_NHW_K;
|
||||
using BiasLayout = ctc::G_K;
|
||||
using ResidualLayout = ctc::G_NHW_K;
|
||||
using OutLayout = ctc::G_NHW_K;
|
||||
|
||||
@@ -332,7 +332,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using InLayout = ctc::G_NDHW_C;
|
||||
using WeiLayout = ctc::G_K_ZYX_C;
|
||||
using BiasLayout = ctc::G_NDHW_K;
|
||||
using BiasLayout = ctc::G_K;
|
||||
using ResidualLayout = ctc::G_NDHW_K;
|
||||
using OutLayout = ctc::G_NDHW_K;
|
||||
|
||||
|
||||
@@ -137,7 +137,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using InLayout = ctc::G_NW_C;
|
||||
using WeiLayout = ctc::G_K_X_C;
|
||||
using BiasLayout = ctc::G_NW_K;
|
||||
using BiasLayout = ctc::G_K;
|
||||
using ResidualLayout = ctc::G_NW_K;
|
||||
using OutLayout = ctc::G_NW_K;
|
||||
|
||||
@@ -220,7 +220,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using InLayout = ctc::G_NHW_C;
|
||||
using WeiLayout = ctc::G_K_YX_C;
|
||||
using BiasLayout = ctc::G_NHW_K;
|
||||
using BiasLayout = ctc::G_K;
|
||||
using ResidualLayout = ctc::G_NHW_K;
|
||||
using OutLayout = ctc::G_NHW_K;
|
||||
|
||||
@@ -332,7 +332,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using InLayout = ctc::G_NDHW_C;
|
||||
using WeiLayout = ctc::G_K_ZYX_C;
|
||||
using BiasLayout = ctc::G_NDHW_K;
|
||||
using BiasLayout = ctc::G_K;
|
||||
using ResidualLayout = ctc::G_NDHW_K;
|
||||
using OutLayout = ctc::G_NDHW_K;
|
||||
|
||||
|
||||
@@ -137,7 +137,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using InLayout = ctc::G_NW_C;
|
||||
using WeiLayout = ctc::G_K_X_C;
|
||||
using BiasLayout = ctc::G_NW_K;
|
||||
using BiasLayout = ctc::G_K;
|
||||
using ResidualLayout = ctc::G_NW_K;
|
||||
using OutLayout = ctc::G_NW_K;
|
||||
|
||||
@@ -220,7 +220,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using InLayout = ctc::G_NHW_C;
|
||||
using WeiLayout = ctc::G_K_YX_C;
|
||||
using BiasLayout = ctc::G_NHW_K;
|
||||
using BiasLayout = ctc::G_K;
|
||||
using ResidualLayout = ctc::G_NHW_K;
|
||||
using OutLayout = ctc::G_NHW_K;
|
||||
|
||||
@@ -332,7 +332,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using InLayout = ctc::G_NDHW_C;
|
||||
using WeiLayout = ctc::G_K_ZYX_C;
|
||||
using BiasLayout = ctc::G_NDHW_K;
|
||||
using BiasLayout = ctc::G_K;
|
||||
using ResidualLayout = ctc::G_NDHW_K;
|
||||
using OutLayout = ctc::G_NDHW_K;
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp)
|
||||
@@ -0,0 +1,199 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
|
||||
|
||||
void print_helper_msg()
|
||||
{
|
||||
std::cout << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: time kernel (0=no, 1=yes)\n"
|
||||
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename OutDataType,
|
||||
typename WeiDataType,
|
||||
typename BiasDataType,
|
||||
typename InDataType,
|
||||
typename OutElementOp,
|
||||
typename WeiElementOp,
|
||||
typename InElementOp,
|
||||
typename DeviceInstance>
|
||||
int run_conv_bwd_data_bias_relu(bool do_verification,
|
||||
int init_method,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
const HostTensorDescriptor& out_g_n_k_wos_desc,
|
||||
const HostTensorDescriptor& wei_g_k_c_xs_desc,
|
||||
const HostTensorDescriptor& bias_g_n_c_wis_desc,
|
||||
const HostTensorDescriptor& in_g_n_c_wis_desc,
|
||||
const OutElementOp& out_element_op,
|
||||
const WeiElementOp& wei_element_op,
|
||||
const InElementOp& in_element_op)
|
||||
{
|
||||
Tensor<OutDataType> out(out_g_n_k_wos_desc);
|
||||
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
|
||||
Tensor<BiasDataType> bias(bias_g_n_c_wis_desc);
|
||||
Tensor<InDataType> in_host(in_g_n_c_wis_desc);
|
||||
Tensor<InDataType> in_device(in_g_n_c_wis_desc);
|
||||
|
||||
std::cout << "out: " << out.mDesc << std::endl;
|
||||
std::cout << "wei: " << wei.mDesc << std::endl;
|
||||
std::cout << "bias: " << bias.mDesc << std::endl;
|
||||
std::cout << "in: " << in_host.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
|
||||
bias.GenerateTensorValue(GeneratorTensor_2<BiasDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
|
||||
bias.GenerateTensorValue(GeneratorTensor_3<BiasDataType>{0.0, 1.0});
|
||||
}
|
||||
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
|
||||
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias.mDesc.GetElementSpaceSize());
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize());
|
||||
|
||||
out_device_buf.ToDevice(out.mData.data());
|
||||
wei_device_buf.ToDevice(wei.mData.data());
|
||||
bias_device_buf.ToDevice(bias.mData.data());
|
||||
|
||||
// reset input to zero
|
||||
in_device_buf.SetZero();
|
||||
|
||||
std::array<ck::index_t, NDimSpatial + 3> a_g_n_k_wos_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> a_g_n_k_wos_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> d0_g_n_c_wis_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> d0_g_n_c_wis_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> e_g_n_c_wis_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> e_g_n_c_wis_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads{};
|
||||
|
||||
auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
|
||||
|
||||
copy(out_g_n_k_wos_desc.GetLengths(), a_g_n_k_wos_lengths);
|
||||
copy(out_g_n_k_wos_desc.GetStrides(), a_g_n_k_wos_strides);
|
||||
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
|
||||
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
|
||||
copy(bias_g_n_c_wis_desc.GetLengths(), d0_g_n_c_wis_lengths);
|
||||
copy(bias_g_n_c_wis_desc.GetStrides(), d0_g_n_c_wis_strides);
|
||||
copy(in_g_n_c_wis_desc.GetLengths(), e_g_n_c_wis_lengths);
|
||||
copy(in_g_n_c_wis_desc.GetStrides(), e_g_n_c_wis_strides);
|
||||
copy(conv_param.conv_filter_strides_, conv_filter_strides);
|
||||
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
|
||||
copy(conv_param.input_left_pads_, input_left_pads);
|
||||
copy(conv_param.input_right_pads_, input_right_pads);
|
||||
|
||||
// do conv
|
||||
auto conv = DeviceInstance{};
|
||||
auto invoker = conv.MakeInvoker();
|
||||
auto argument = conv.MakeArgument(
|
||||
out_device_buf.GetDeviceBuffer(),
|
||||
wei_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{bias_device_buf.GetDeviceBuffer()},
|
||||
in_device_buf.GetDeviceBuffer(),
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{d0_g_n_c_wis_lengths},
|
||||
std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{d0_g_n_c_wis_strides},
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
out_element_op,
|
||||
wei_element_op,
|
||||
in_element_op);
|
||||
|
||||
if(!conv.IsSupportedArgument(argument))
|
||||
{
|
||||
printf("wrong! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem\n");
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = conv_param.GetFlops();
|
||||
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// c doesn't physically exist, any layout is fine
|
||||
Tensor<float> c_host(in_g_n_c_wis_desc);
|
||||
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<NDimSpatial,
|
||||
float,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
PassThrough,
|
||||
WeiElementOp,
|
||||
OutElementOp>();
|
||||
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_conv.MakeArgument(c_host,
|
||||
wei,
|
||||
out,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_,
|
||||
PassThrough{},
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
// TODO: implement elementwise operation for host
|
||||
in_host.ForEach(
|
||||
[&](auto&, auto idx) { in_element_op(in_host(idx), c_host(idx), bias(idx)); });
|
||||
|
||||
in_device_buf.FromDevice(in_device.mData.data());
|
||||
|
||||
return ck::utils::check_err(in_device.mData, in_host.mData) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "grouped_conv_bwd_data_bias_relu_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using OutDataType = ck::half_t;
|
||||
using WeiDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using BiasDataType = ck::half_t; // bias
|
||||
using InDataType = ck::half_t;
|
||||
|
||||
using OutLayout = ck::tensor_layout::convolution::GNHWK;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using BiasLayout = ck::tensor_layout::convolution::G_C;
|
||||
using InLayout = ck::tensor_layout::convolution::GNHWC;
|
||||
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CBiasInElementOp = ck::tensor_operation::element_wise::AddRelu;
|
||||
|
||||
static constexpr auto ConvBwdDataDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvNdBwdDataInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<
|
||||
NDimSpatial,
|
||||
OutLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<BiasLayout>,
|
||||
InLayout,
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ck::Tuple<BiasDataType>,
|
||||
InDataType,
|
||||
OutElementOp,
|
||||
WeiElementOp,
|
||||
CBiasInElementOp,
|
||||
ConvBwdDataDefault,
|
||||
true, // DoPadGemmM
|
||||
true, // DoPadGemmN
|
||||
1,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
32,
|
||||
8,
|
||||
2,
|
||||
32,
|
||||
32,
|
||||
2,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
1,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
4,
|
||||
2,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
|
||||
print_helper_msg();
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::utils::conv::ConvParam conv_param{
|
||||
2, 2, 128, 256, 256, {3, 3}, {14, 14}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
|
||||
|
||||
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
|
||||
}
|
||||
|
||||
const auto in_element_op = CBiasInElementOp{};
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
const auto out_element_op = OutElementOp{};
|
||||
|
||||
if(conv_param.num_dim_spatial_ == 2)
|
||||
{
|
||||
// output image: GNHWK
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
// weight: GKYXC
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
// input image bias: G_C
|
||||
const auto bias_g_n_c_wis_desc =
|
||||
HostTensorDescriptor({conv_param.G_,
|
||||
conv_param.N_,
|
||||
conv_param.C_,
|
||||
conv_param.input_spatial_lengths_[0],
|
||||
conv_param.input_spatial_lengths_[1]},
|
||||
{
|
||||
conv_param.C_, // g
|
||||
0, // n
|
||||
1, // c
|
||||
0, // hi
|
||||
0 // wi
|
||||
});
|
||||
|
||||
// input image: GNHWC
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
using DeviceInstance = DeviceConvNdBwdDataInstance<2>;
|
||||
|
||||
run_conv_bwd_data_bias_relu<2,
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
BiasDataType,
|
||||
InDataType,
|
||||
OutElementOp,
|
||||
WeiElementOp,
|
||||
CBiasInElementOp,
|
||||
DeviceInstance>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
out_g_n_k_wos_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
bias_g_n_c_wis_desc,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
in_element_op);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -21,36 +21,10 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
add_dependencies(examples ${EXAMPLE_NAME})
|
||||
endfunction(add_example_executable_no_testing EXAMPLE_NAME)
|
||||
|
||||
add_subdirectory(01_gemm)
|
||||
add_subdirectory(02_gemm_bilinear)
|
||||
add_subdirectory(03_gemm_bias_relu)
|
||||
add_subdirectory(04_gemm_add_add_fastgelu)
|
||||
add_subdirectory(09_convnd_fwd)
|
||||
add_subdirectory(10_convnd_fwd_multiple_d_multiple_reduce)
|
||||
add_subdirectory(12_reduce)
|
||||
add_subdirectory(13_pool2d_fwd)
|
||||
add_subdirectory(14_gemm_xdl_requant_relu_requant)
|
||||
add_subdirectory(15_grouped_gemm)
|
||||
add_subdirectory(16_gemm_multi_d_multi_reduces)
|
||||
add_subdirectory(17_convnd_bwd_data)
|
||||
add_subdirectory(18_batched_gemm_reduce)
|
||||
add_subdirectory(19_binary_elementwise)
|
||||
add_subdirectory(20_convnd_bwd_weight)
|
||||
add_subdirectory(21_gemm_layernorm)
|
||||
add_subdirectory(22_cgemm)
|
||||
add_subdirectory(23_softmax)
|
||||
add_subdirectory(24_batched_gemm)
|
||||
add_subdirectory(25_gemm_bias_e_permute)
|
||||
add_subdirectory(26_contraction)
|
||||
add_subdirectory(27_layernorm)
|
||||
add_subdirectory(28_grouped_gemm_bias_e_permute)
|
||||
add_subdirectory(29_batched_gemm_bias_e_permute)
|
||||
add_subdirectory(30_grouped_convnd_fwd_bias_relu_add)
|
||||
add_subdirectory(31_batched_gemm_gemm)
|
||||
add_subdirectory(32_batched_gemm_scale_softmax_gemm)
|
||||
add_subdirectory(33_multiple_reduce)
|
||||
add_subdirectory(34_batchnorm)
|
||||
add_subdirectory(35_splitK_gemm)
|
||||
add_subdirectory(36_sparse_embedding)
|
||||
add_subdirectory(37_batched_gemm_add_add_relu_gemm_add)
|
||||
add_subdirectory(41_grouped_conv_conv_fwd)
|
||||
# add all example subdir
|
||||
file(GLOB dir_list LIST_DIRECTORIES true *)
|
||||
FOREACH(subdir ${dir_list})
|
||||
IF(IS_DIRECTORY "${subdir}")
|
||||
add_subdirectory(${subdir})
|
||||
ENDIF()
|
||||
ENDFOREACH()
|
||||
|
||||
@@ -549,10 +549,6 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_M_K,
|
||||
BGridDesc_N_K,
|
||||
DsGridDesc_M_N,
|
||||
EGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -586,12 +582,19 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
|
||||
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -719,10 +722,9 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// block-to-e-tile map
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
@@ -786,10 +788,10 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
typename GridwiseGemm::DefaultBlock2ETileMap,
|
||||
DeviceOp::Block2ETileMap,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
|
||||
@@ -1,682 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/*
|
||||
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
|
||||
*
|
||||
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
|
||||
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
|
||||
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
|
||||
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
* limitations.
|
||||
*
|
||||
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
|
||||
* returns the 2D index of the tile that it computes. \see
|
||||
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
|
||||
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
|
||||
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
|
||||
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
|
||||
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
|
||||
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
|
||||
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
|
||||
*
|
||||
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
|
||||
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
|
||||
* realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion).
|
||||
*
|
||||
*/
|
||||
template <typename GridwiseGemm,
|
||||
typename ABDataType,
|
||||
typename EDataType,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
typename Block2ETileMap,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const index_t batch_count,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
ck::Tuple<>{},
|
||||
p_e_grid + e_batch_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ck::Tuple<>{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_e_grid;
|
||||
ignore = batch_count;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = block_2_etile_map;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumPrefetch,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_K1,
|
||||
index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_K1,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmEPermuteXdl;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(StrideA, I1));
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(I1, StrideA));
|
||||
}
|
||||
}();
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
|
||||
{
|
||||
const auto e_grid_desc_mraw_nraw =
|
||||
make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(stride_M, stride_N));
|
||||
|
||||
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0,
|
||||
index_t G1,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t stride_G0,
|
||||
index_t stride_G1,
|
||||
index_t stride_M,
|
||||
index_t stride_N)
|
||||
{
|
||||
const auto e_grid_desc_g0_g1_mraw_nraw = [&]() {
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(G0, G1, MRaw, NRaw),
|
||||
make_tuple(stride_G0, stride_G1, stride_M, stride_N));
|
||||
}();
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad M and N
|
||||
return transform_tensor_descriptor(
|
||||
e_grid_desc_g0_g1_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(G0),
|
||||
make_pass_through_transform(G1),
|
||||
make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad M, but not N
|
||||
return transform_tensor_descriptor(
|
||||
e_grid_desc_g0_g1_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(G0),
|
||||
make_pass_through_transform(G1),
|
||||
make_right_pad_transform(MRaw, MPad),
|
||||
make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad N, but not M
|
||||
return transform_tensor_descriptor(
|
||||
e_grid_desc_g0_g1_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(G0),
|
||||
make_pass_through_transform(G1),
|
||||
make_pass_through_transform(MRaw),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or N
|
||||
return e_grid_desc_g0_g1_mraw_nraw;
|
||||
}
|
||||
}
|
||||
|
||||
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
|
||||
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1, 1));
|
||||
using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1));
|
||||
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch(index_t Batchstride_A,
|
||||
index_t Batchstride_B,
|
||||
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n)
|
||||
: Batchstride_A_(Batchstride_A),
|
||||
Batchstride_B_(Batchstride_B),
|
||||
e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(Batchstride_A_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(Batchstride_B_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
|
||||
{
|
||||
const index_t G1 = e_grid_desc_g0_g1_m_n_.GetLength(I1);
|
||||
index_t b0 = g_idx / G1;
|
||||
index_t b1 = g_idx - b0 * G1; // g_idx % G1
|
||||
return e_grid_desc_g0_g1_m_n_.CalculateOffset(make_multi_index(b0, b1, 0, 0));
|
||||
}
|
||||
|
||||
private:
|
||||
index_t Batchstride_A_;
|
||||
index_t Batchstride_B_;
|
||||
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
|
||||
};
|
||||
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ck::Tuple<>, // DsDataType,
|
||||
EDataType, // EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_M_K,
|
||||
BGridDesc_N_K,
|
||||
Tuple<>,
|
||||
EGridDesc_M_N,
|
||||
NumPrefetch,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
|
||||
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
EDataType* p_e_grid,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_A,
|
||||
index_t stride_B,
|
||||
index_t batch_stride_A,
|
||||
index_t batch_stride_B,
|
||||
BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
|
||||
index_t BatchCount,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_e_grid_{p_e_grid},
|
||||
BatchCount_(BatchCount),
|
||||
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(M, K, stride_A)},
|
||||
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(K, N, stride_B)},
|
||||
e_grid_desc_m_n_{
|
||||
DeviceOp::MakeEGridDescriptor_M_N(batched_gemm_e_permute_desc.M_,
|
||||
batched_gemm_e_permute_desc.N_,
|
||||
batched_gemm_e_permute_desc.stride_M_,
|
||||
batched_gemm_e_permute_desc.stride_N_)},
|
||||
a_grid_desc_ak0_m_ak1_{
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
|
||||
b_grid_desc_bk0_n_bk1_{
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock{},
|
||||
e_grid_desc_g0_g1_m_n_{
|
||||
DeviceOp::MakeEGridDescriptor_G0_G1_M_N(batched_gemm_e_permute_desc.G0_,
|
||||
batched_gemm_e_permute_desc.G1_,
|
||||
batched_gemm_e_permute_desc.M_,
|
||||
batched_gemm_e_permute_desc.N_,
|
||||
batched_gemm_e_permute_desc.stride_G0_,
|
||||
batched_gemm_e_permute_desc.stride_G1_,
|
||||
batched_gemm_e_permute_desc.stride_M_,
|
||||
batched_gemm_e_permute_desc.stride_N_)},
|
||||
compute_ptr_offset_of_batch_{batch_stride_A, batch_stride_B, e_grid_desc_g0_g1_m_n_},
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
|
||||
b_grid_desc_n_k_,
|
||||
ck::Tuple<>{},
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_))
|
||||
{
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const
|
||||
{
|
||||
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
|
||||
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
|
||||
std::cout << "C[M, N]: " << e_grid_desc_m_n_ << std::endl;
|
||||
}
|
||||
|
||||
// private:
|
||||
// pointers
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// batch count
|
||||
index_t BatchCount_;
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_N_K b_grid_desc_n_k_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
|
||||
|
||||
// for calculating Batch offset
|
||||
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
|
||||
|
||||
// block-to-e-tile map
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
ck::Tuple<>{},
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid "
|
||||
"setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.BatchCount_;
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop_) {
|
||||
const auto kernel = kernel_batched_gemm_e_permute_xdl<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
EDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_AK0_M_AK1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_BK0_N_BK1>,
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
remove_reference_t<Block2ETileMap>,
|
||||
has_main_k_block_loop_>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.block_2_etile_map_);
|
||||
};
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
ck::Tuple<>{},
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
EDataType* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_A,
|
||||
index_t stride_B,
|
||||
index_t batch_stride_A,
|
||||
index_t batch_stride_B,
|
||||
BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
|
||||
index_t BatchCount,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_e,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batched_gemm_e_permute_desc,
|
||||
BatchCount,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_A,
|
||||
index_t stride_B,
|
||||
index_t batch_stride_A,
|
||||
index_t batch_stride_B,
|
||||
BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
|
||||
index_t BatchCount,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<EDataType*>(p_e),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batched_gemm_e_permute_desc,
|
||||
BatchCount,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchedGemmEPermuteXdl"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -333,10 +333,6 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_M_K,
|
||||
BGridDesc_N_K,
|
||||
DsGridDesc_M_N,
|
||||
EGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -370,12 +366,19 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
|
||||
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -478,10 +481,9 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// for calculating batch offset
|
||||
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
|
||||
@@ -520,21 +522,21 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
const auto kernel = kernel_batched_gemm_xdl<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
Block2ETileMap,
|
||||
has_main_loop>;
|
||||
const auto kernel =
|
||||
kernel_batched_gemm_xdl<GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
Block2ETileMap,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
|
||||
@@ -320,10 +320,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_M_K,
|
||||
BGridDesc_N_K,
|
||||
DsGridDesc_M_N,
|
||||
EGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -357,12 +353,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
|
||||
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -475,10 +478,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// block-to-e-tile map
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
@@ -535,9 +537,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2ETileMap,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::Block2ETileMap,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
|
||||
@@ -237,10 +237,6 @@ struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_M_K,
|
||||
BGridDesc_N_K,
|
||||
DsGridDesc_M_N,
|
||||
EGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
|
||||
@@ -234,6 +234,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
|
||||
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
|
||||
@@ -250,10 +251,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_M_K,
|
||||
BGridDesc_N_K,
|
||||
DsGridDesc_M_N,
|
||||
EGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -287,10 +284,19 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -383,13 +389,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// block-to-e-tile map
|
||||
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
@@ -432,9 +437,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2ETileMap,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::Block2ETileMap,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
|
||||
@@ -365,10 +365,6 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_M_K,
|
||||
BGridDesc_N_K,
|
||||
DsGridDesc_M_N,
|
||||
EGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -402,17 +398,21 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
|
||||
struct GroupedContractionBlock2ETileMap
|
||||
{
|
||||
static_assert(
|
||||
std::is_same<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})),
|
||||
typename GridwiseGemm::DefaultBlock2ETileMap>::value,
|
||||
"Wrong! Should be the same type name");
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
GroupedContractionBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n,
|
||||
ck::index_t BlockStart)
|
||||
@@ -441,7 +441,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
|
||||
return default_block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
|
||||
}
|
||||
|
||||
typename GridwiseGemm::DefaultBlock2ETileMap default_block_2_etile_map_;
|
||||
Block2ETileMap default_block_2_etile_map_;
|
||||
ck::index_t block_start_;
|
||||
};
|
||||
|
||||
@@ -456,10 +456,9 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// lock-to-e-tile map
|
||||
GroupedContractionBlock2ETileMap block_2_etile_map_;
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Conv backward data multiple D:
|
||||
// input : output image A[G, N, K, Ho, Wo]
|
||||
// input : weight B[G, K, C, Y, X],
|
||||
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
|
||||
// output : input image E[G, N, C, Hi, Wi],
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
|
||||
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const void* p_a, // output image
|
||||
const void* p_b, // weight
|
||||
const std::array<const void*, NumDTensor>& p_ds, // bias
|
||||
void* p_e, // input image
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths, // bias
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides, // bias
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -34,11 +34,13 @@ struct DeviceGroupedConvFwdMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
|
||||
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_a, // input image
|
||||
const void* p_b, // weight
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
void* p_e, // output image
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
|
||||
@@ -117,7 +117,7 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batch_gemm_multiple_d_xdl_cshuffle(
|
||||
kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle(
|
||||
const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
@@ -136,8 +136,7 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
|
||||
#if 1
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
@@ -174,24 +173,6 @@ __global__ void
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
block_2_ctile_map);
|
||||
#else
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
block_2_ctile_map);
|
||||
#endif
|
||||
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
@@ -378,6 +359,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
using AGridDesc_M_K = remove_cvref_t<decltype(
|
||||
MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
|
||||
@@ -395,10 +377,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_M_K,
|
||||
BGridDesc_N_K,
|
||||
DsGridDesc_M_N,
|
||||
EGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -432,12 +410,19 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
|
||||
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -467,6 +452,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
p_b_grid_{static_cast<const BDataType*>(p_b)},
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
num_group_{a_g_n_c_wis_lengths[0]},
|
||||
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
@@ -561,6 +547,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
index_t num_group_;
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_N_K b_grid_desc_n_k_;
|
||||
DsGridDesc_M_N ds_grid_desc_m_n_;
|
||||
@@ -569,14 +556,14 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// block-to-e-tile map
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
|
||||
|
||||
// element-wise op
|
||||
@@ -622,8 +609,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) *
|
||||
arg.a_g_n_c_wis_lengths_[0]; // Group count
|
||||
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
@@ -631,7 +617,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
const auto kernel = kernel_batch_gemm_multiple_d_xdl_cshuffle<
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
@@ -641,8 +627,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumDTensor>,
|
||||
has_main_loop>;
|
||||
@@ -798,7 +784,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> ||
|
||||
is_same_v<DLayout, ctc::GNHWK> || is_same_v<DLayout, ctc::GNDHWK> ||
|
||||
is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
|
||||
is_same_v<DLayout, ctc::NDHWGK>)
|
||||
is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::GK> ||
|
||||
is_same_v<DLayout, ctc::G_K>)
|
||||
{
|
||||
const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
|
||||
|
||||
|
||||
@@ -238,10 +238,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_M_K,
|
||||
BGridDesc_N_K,
|
||||
DsGridDesc_M_N,
|
||||
EGridDesc_M_N,
|
||||
NumPrefetch, // NumGemmKPrefetchStage
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -275,19 +271,19 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
|
||||
struct GroupedGemmBlock2ETileMap
|
||||
{
|
||||
using UnderlyingBlock2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
|
||||
|
||||
static_assert(
|
||||
std::is_same<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})),
|
||||
typename GridwiseGemm::DefaultBlock2ETileMap>::value,
|
||||
"Wrong! Should be the same type name");
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
GroupedGemmBlock2ETileMap()
|
||||
{
|
||||
@@ -321,7 +317,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
return block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
|
||||
}
|
||||
|
||||
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
ck::index_t BlockStart_;
|
||||
};
|
||||
|
||||
@@ -342,10 +338,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// block-to-e-tile map
|
||||
GroupedGemmBlock2ETileMap block_2_etile_map_;
|
||||
@@ -440,7 +435,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
block_2_etile_map))
|
||||
{
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -92,6 +92,12 @@ struct GNDHWC : public BaseTensorLayout
|
||||
static constexpr const char* name = "GNDHWC";
|
||||
};
|
||||
|
||||
// for input bias
|
||||
struct GC : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "GC";
|
||||
};
|
||||
|
||||
// input tensor
|
||||
// packed NWGC/NHWGC/NDHWGC
|
||||
struct NWGC : public BaseTensorLayout
|
||||
@@ -126,6 +132,12 @@ struct G_NDHW_C : public BaseTensorLayout
|
||||
static constexpr const char* name = "G_NDHW_C";
|
||||
};
|
||||
|
||||
// for input bias
|
||||
struct G_C : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "G_C";
|
||||
};
|
||||
|
||||
// weight tensor
|
||||
// packed KCX/KCYX/KCZYX
|
||||
struct KCX : public BaseTensorLayout
|
||||
@@ -296,6 +308,12 @@ struct GNDHWK : public BaseTensorLayout
|
||||
static constexpr const char* name = "GNDHWK";
|
||||
};
|
||||
|
||||
// for output bias
|
||||
struct GK : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "GK";
|
||||
};
|
||||
|
||||
// output tensor
|
||||
// packed NWGK/NHWGK/NDHWGK
|
||||
struct NWGK : public BaseTensorLayout
|
||||
@@ -330,6 +348,12 @@ struct G_NDHW_K : public BaseTensorLayout
|
||||
static constexpr const char* name = "G_NDHW_K";
|
||||
};
|
||||
|
||||
// for output bias
|
||||
struct G_K : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "G_K";
|
||||
};
|
||||
|
||||
// K-reduced output tensor (packed)
|
||||
struct GNW : public BaseTensorLayout
|
||||
{
|
||||
|
||||
@@ -35,10 +35,6 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
typename AGridDesc_M_K,
|
||||
typename BGridDesc_N_K,
|
||||
typename DsGridDesc_M_N,
|
||||
typename EGridDesc_M_N,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
@@ -166,6 +162,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
}
|
||||
|
||||
// A desc for source in blockwise copy
|
||||
template <typename AGridDesc_M_K>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
|
||||
{
|
||||
@@ -182,6 +179,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
}
|
||||
|
||||
// B desc for source in blockwise copy
|
||||
template <typename BGridDesc_N_K>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
|
||||
{
|
||||
@@ -198,9 +196,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
}
|
||||
|
||||
// E desc for destination in blockwise copy
|
||||
template <typename EGridDescriptor_M_N>
|
||||
__host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const EGridDescriptor_M_N& e_grid_desc_m_n)
|
||||
template <typename EGridDesc_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
|
||||
{
|
||||
const auto M = e_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = e_grid_desc_m_n.GetLength(I1);
|
||||
@@ -219,10 +217,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
}
|
||||
|
||||
// Ds desc for source in blockwise copy
|
||||
template <typename DsGridDescriptor_M_N>
|
||||
template <typename DsGridDesc_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DsGridDescriptor_M_N& ds_grid_desc_m_n)
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -232,6 +229,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
}
|
||||
|
||||
// return block_id to E matrix tile idx (m0, n0) mapping
|
||||
template <typename EGridDesc_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
|
||||
{
|
||||
@@ -240,7 +238,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
template <typename Block2ETileMap>
|
||||
template <typename AGridDesc_M_K,
|
||||
typename BGridDesc_N_K,
|
||||
typename DsGridDesc_M_N,
|
||||
typename EGridDesc_M_N,
|
||||
typename Block2ETileMap>
|
||||
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
|
||||
const BGridDesc_N_K& b_grid_desc_n_k,
|
||||
const DsGridDesc_M_N& ds_grid_desc_m_n,
|
||||
@@ -314,23 +316,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
|
||||
}
|
||||
|
||||
using DefaultAGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using DefaultBGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
|
||||
using DefaultBlock2ETileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
using DsGridPointer = decltype(MakeDsGridPointer());
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2ETileMap>
|
||||
__device__ static void Run(const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
@@ -342,9 +334,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
const CDEElementwiseOperation& cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap& block_2_etile_map)
|
||||
{
|
||||
|
||||
@@ -0,0 +1,583 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
template <
|
||||
index_t NDimSpatial,
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
bool DoPadGemmM,
|
||||
bool DoPadGemmN>
|
||||
struct TransformConvBwdDataToGemm_v1
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
template <typename ALayout,
|
||||
typename std::enable_if<NDimSpatial == 2 &&
|
||||
is_same_v<ALayout, tensor_layout::convolution::GNHWK>,
|
||||
bool>::type = false>
|
||||
static auto MakeADescriptor_AK0_M_AK1(
|
||||
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
|
||||
const std::array<index_t, NDimSpatial>& tildes)
|
||||
{
|
||||
index_t i_ytilde = tildes[0];
|
||||
index_t i_xtilde = tildes[1];
|
||||
|
||||
const index_t N = in_g_n_c_wis_lengths[1];
|
||||
const index_t K = wei_g_k_c_xs_lengths[1];
|
||||
|
||||
const index_t Hi = in_g_n_c_wis_lengths[3];
|
||||
const index_t Wi = in_g_n_c_wis_lengths[4];
|
||||
|
||||
const index_t Ho = out_g_n_k_wos_lengths[3];
|
||||
const index_t Wo = out_g_n_k_wos_lengths[4];
|
||||
|
||||
const index_t Y = wei_g_k_c_xs_lengths[3];
|
||||
const index_t X = wei_g_k_c_xs_lengths[4];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const index_t AK0 = K / AK1;
|
||||
|
||||
// assume packed
|
||||
const auto out_n_ho_wo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
|
||||
|
||||
if constexpr(ConvBwdDataSpecialization ==
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0)
|
||||
{
|
||||
// A: output tensor
|
||||
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_unmerge_transform(make_tuple(AK0, AK1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmak0_gemmmraw_gemmak1_grid_desc,
|
||||
make_tuple(AK0, GemmMPerBlock, AK1),
|
||||
Sequence<false, DoPadGemmM, false>{});
|
||||
|
||||
return out_gemmak0_gemmm_gemmak1_grid_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
const auto YDot = math::integer_divide_ceil(Y, YTilde);
|
||||
const auto XDot = math::integer_divide_ceil(X, XTilde);
|
||||
|
||||
const auto HTilde =
|
||||
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
|
||||
const auto WTilde =
|
||||
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
|
||||
|
||||
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
|
||||
const auto IHTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
|
||||
const auto IWTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
|
||||
|
||||
const auto IHTildeSliceEnd = math::min(
|
||||
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
|
||||
const auto IWTildeSliceEnd = math::min(
|
||||
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
|
||||
|
||||
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
|
||||
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
|
||||
|
||||
// GemmK is different for each GEMM
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
|
||||
// A: output tensor
|
||||
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ho_wo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, I0, I0),
|
||||
make_pad_transform(Wo, I0, I0),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
|
||||
out_n_hop_wop_k_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YDot, HTilde),
|
||||
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, WTilde),
|
||||
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_unmerge_transform(make_tuple(AK0, AK1))),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5, 6>{}));
|
||||
|
||||
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, AK0)),
|
||||
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(AK1)),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmak0_gemmmraw_gemmak1_grid_desc,
|
||||
make_tuple(AK0, GemmMPerBlock, AK1),
|
||||
Sequence<false, DoPadGemmM, false>{});
|
||||
|
||||
return out_gemmak0_gemmm_gemmak1_grid_desc;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BLayout,
|
||||
typename std::enable_if<NDimSpatial == 2 &&
|
||||
is_same_v<BLayout, tensor_layout::convolution::GKYXC>,
|
||||
bool>::type = false>
|
||||
static auto MakeBDescriptor_BK0_N_BK1(
|
||||
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& /* input_left_pads */,
|
||||
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
|
||||
const std::array<index_t, NDimSpatial>& tildes)
|
||||
{
|
||||
index_t i_ytilde = tildes[0];
|
||||
index_t i_xtilde = tildes[1];
|
||||
|
||||
const index_t N = in_g_n_c_wis_lengths[1];
|
||||
const index_t K = wei_g_k_c_xs_lengths[1];
|
||||
const index_t C = wei_g_k_c_xs_lengths[2];
|
||||
|
||||
const index_t Ho = out_g_n_k_wos_lengths[3];
|
||||
const index_t Wo = out_g_n_k_wos_lengths[4];
|
||||
|
||||
const index_t Y = wei_g_k_c_xs_lengths[3];
|
||||
const index_t X = wei_g_k_c_xs_lengths[4];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const index_t BK0 = K / BK1;
|
||||
|
||||
// assume packed
|
||||
const auto wei_k_y_x_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
|
||||
|
||||
if constexpr(ConvBwdDataSpecialization ==
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0)
|
||||
{
|
||||
// B: weight tensor
|
||||
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
|
||||
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1));
|
||||
|
||||
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
|
||||
make_tuple(BK0, GemmNPerBlock, BK1),
|
||||
Sequence<false, DoPadGemmN, false>{});
|
||||
|
||||
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
const auto YDot = math::integer_divide_ceil(Y, YTilde);
|
||||
const auto XDot = math::integer_divide_ceil(X, XTilde);
|
||||
|
||||
// GemmK is different for each GEMM
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
|
||||
// B weight tensor
|
||||
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_y_x_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_embed_transform(make_tuple(YDot, YTilde),
|
||||
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, XTilde),
|
||||
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc =
|
||||
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_freeze_transform(i_ytilde),
|
||||
make_freeze_transform(i_xtilde),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<3>{},
|
||||
Sequence<2>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0, 1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<>{},
|
||||
Sequence<>{},
|
||||
Sequence<4>{}));
|
||||
|
||||
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, BK0)),
|
||||
make_pass_through_transform(C),
|
||||
make_pass_through_transform(BK1)),
|
||||
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
|
||||
make_tuple(
|
||||
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0), GemmNPerBlock, BK1),
|
||||
Sequence<false, DoPadGemmN, false>{});
|
||||
|
||||
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CLayout,
|
||||
typename std::enable_if<NDimSpatial == 2 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>),
|
||||
bool>::type = false>
|
||||
static auto
|
||||
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const std::array<index_t, NDimSpatial>& tildes)
|
||||
{
|
||||
index_t i_ytilde = tildes[0];
|
||||
index_t i_xtilde = tildes[1];
|
||||
|
||||
const index_t N = in_g_n_c_wis_lengths[1];
|
||||
const index_t C = wei_g_k_c_xs_lengths[2];
|
||||
|
||||
const index_t Hi = in_g_n_c_wis_lengths[3];
|
||||
const index_t Wi = in_g_n_c_wis_lengths[4];
|
||||
|
||||
const index_t Ho = out_g_n_k_wos_lengths[3];
|
||||
const index_t Wo = out_g_n_k_wos_lengths[4];
|
||||
|
||||
const index_t Y = wei_g_k_c_xs_lengths[3];
|
||||
const index_t X = wei_g_k_c_xs_lengths[4];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
// assume strided
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
|
||||
make_tuple(in_g_n_c_wis_strides[1],
|
||||
in_g_n_c_wis_strides[3],
|
||||
in_g_n_c_wis_strides[4],
|
||||
in_g_n_c_wis_strides[2]));
|
||||
|
||||
if constexpr(ConvBwdDataSpecialization ==
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0)
|
||||
{
|
||||
// C: input tensor
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_freeze_transform(I0),
|
||||
make_freeze_transform(I0),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
|
||||
in_gemmmraw_gemmnraw_grid_desc,
|
||||
make_tuple(GemmMPerBlock, GemmNPerBlock),
|
||||
Sequence<DoPadGemmM, DoPadGemmN>{});
|
||||
|
||||
return in_gemmm_gemmn_grid_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
const auto HTilde =
|
||||
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
|
||||
const auto WTilde =
|
||||
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
|
||||
|
||||
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
|
||||
const auto IHTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
|
||||
const auto IWTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
|
||||
|
||||
const auto IHTildeSliceEnd = math::min(
|
||||
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
|
||||
const auto IWTildeSliceEnd = math::min(
|
||||
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
|
||||
|
||||
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
|
||||
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
|
||||
|
||||
// C: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YTilde, HTilde),
|
||||
make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(XTilde, WTilde),
|
||||
make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_freeze_transform(i_ytilde),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_freeze_transform(i_xtilde),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<>{},
|
||||
Sequence<1>{},
|
||||
Sequence<>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{}));
|
||||
|
||||
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
|
||||
in_n_htildeslice_wtildeslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
|
||||
in_gemmmraw_gemmnraw_grid_desc,
|
||||
make_tuple(GemmMPerBlock, GemmNPerBlock),
|
||||
Sequence<DoPadGemmM, DoPadGemmN>{});
|
||||
|
||||
return in_gemmm_gemmn_grid_desc;
|
||||
}
|
||||
}
|
||||
|
||||
// for input bias
|
||||
template <typename CLayout,
|
||||
typename std::enable_if<NDimSpatial == 2 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::GC> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::G_C>),
|
||||
bool>::type = false>
|
||||
static auto
|
||||
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
|
||||
const std::array<index_t, NDimSpatial>& /* tildes */)
|
||||
{
|
||||
const index_t N = in_g_n_c_wis_lengths[1];
|
||||
const index_t C = wei_g_k_c_xs_lengths[2];
|
||||
|
||||
const index_t Hi = in_g_n_c_wis_lengths[3];
|
||||
const index_t Wi = in_g_n_c_wis_lengths[4];
|
||||
|
||||
const index_t Ho = out_g_n_k_wos_lengths[3];
|
||||
const index_t Wo = out_g_n_k_wos_lengths[4];
|
||||
|
||||
const index_t Y = wei_g_k_c_xs_lengths[3];
|
||||
const index_t X = wei_g_k_c_xs_lengths[4];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
if constexpr(ConvBwdDataSpecialization ==
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0)
|
||||
{
|
||||
const auto in_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1));
|
||||
|
||||
return in_gemmm_gemmn_grid_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
const auto HTilde =
|
||||
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
|
||||
const auto WTilde =
|
||||
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
|
||||
|
||||
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
|
||||
const auto IHTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
|
||||
const auto IWTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
|
||||
|
||||
const auto IHTildeSliceEnd = math::min(
|
||||
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
|
||||
const auto IWTildeSliceEnd = math::min(
|
||||
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
|
||||
|
||||
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
|
||||
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
|
||||
|
||||
// bias tensor
|
||||
const auto in_gemmmraw_gemmnraw_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N * HTildeSlice * WTildeSlice, C), make_tuple(I0, I1));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
|
||||
in_gemmmraw_gemmnraw_grid_desc,
|
||||
make_tuple(GemmMPerBlock, GemmNPerBlock),
|
||||
Sequence<DoPadGemmM, DoPadGemmN>{});
|
||||
|
||||
return in_gemmm_gemmn_grid_desc;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -16,6 +16,7 @@ namespace tensor_operation {
|
||||
template <index_t NDimSpatial, device::ConvolutionForwardSpecialization ConvForwardSpecialization>
|
||||
struct TransformConvFwdToGemm
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
template <typename ALayout,
|
||||
@@ -864,6 +865,29 @@ struct TransformConvFwdToGemm
|
||||
|
||||
return out_gemmm_gemmn_desc;
|
||||
}
|
||||
|
||||
// for output bias
|
||||
template <typename CLayout,
|
||||
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::G_K>,
|
||||
bool>::type = false>
|
||||
static auto
|
||||
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
|
||||
{
|
||||
const index_t N = c_g_n_k_wos_lengths[1];
|
||||
const index_t K = c_g_n_k_wos_lengths[2];
|
||||
|
||||
const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
|
||||
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
|
||||
index_t{1},
|
||||
std::multiplies<index_t>());
|
||||
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1));
|
||||
|
||||
return out_gemmm_gemmn_desc;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_IGNORE_HPP
|
||||
#define CK_IGNORE_HPP
|
||||
#pragma once
|
||||
|
||||
// https://en.cppreference.com/w/cpp/utility/tuple/ignore
|
||||
|
||||
@@ -21,4 +20,3 @@ struct ignore_t
|
||||
inline constexpr detail::ignore_t ignore;
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user