mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
added draft of conv bwd data direct loads
This commit is contained in:
@@ -187,6 +187,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
if(threadIdx.x == 0) printf("intra\n");
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
@@ -212,6 +213,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
if(threadIdx.x == 0) printf("hotloop: %d\n", i);
|
||||
// -------------------------------------------------------------------------------------------
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
@@ -280,6 +282,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
if(threadIdx.x == 0) printf("tail\n");
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
@@ -919,6 +922,7 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
if(threadIdx.x == 0) printf("v1 intra directload, num_loop: %d\n", num_loop);
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
@@ -942,6 +946,7 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
if(threadIdx.x == 0) printf("has Main loop %d\n", i);
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
@@ -981,6 +986,14 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
|
||||
if(threadIdx.x == 0) {
|
||||
printf("a: %f b: %f\n",
|
||||
static_cast<float>(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}]),
|
||||
static_cast<float>(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}]));
|
||||
}
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
@@ -1007,6 +1020,7 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
if(threadIdx.x == 0) printf("Tail full\n");
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
@@ -1032,6 +1046,10 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
if(threadIdx.x == 0) {
|
||||
printf("Repeat: (M N K): (%d, %d, %d)\n", m0.value, n0.value, k0.value);
|
||||
}
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
@@ -1039,6 +1057,17 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
|
||||
if(threadIdx.x == 0) {
|
||||
printf("a: %f b: %f a_off: %d b_off: %d\n",
|
||||
static_cast<float>(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}]),
|
||||
static_cast<float>(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}]),
|
||||
a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik)),
|
||||
b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -151,7 +151,8 @@ template <typename ALayout,
|
||||
typename LDSTypeA = ADataType,
|
||||
typename LDSTypeB = BDataType,
|
||||
bool DoElementwiseBeforeCShuffle = false,
|
||||
bool DirectLoad = false>
|
||||
bool DirectLoad = false,
|
||||
bool LdsScalarLoadToVgpr = false>
|
||||
struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
: public GridwiseGemm_xdl_cshuffle_base<
|
||||
ALayout,
|
||||
@@ -197,7 +198,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v4,
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1,
|
||||
DirectLoad>
|
||||
{
|
||||
static_assert((is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
|
||||
@@ -248,7 +249,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v4,
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1,
|
||||
DirectLoad>;
|
||||
|
||||
using Base::AK0Number;
|
||||
@@ -355,7 +356,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
}
|
||||
|
||||
template <typename GridDesc_K0_MN_K1_T, index_t K0Number, index_t K1Value>
|
||||
__host__ __device__ static auto TransformGrid(GridDesc_K0_MN_K1_T& desc)
|
||||
__host__ __device__ static auto TransformGrid(const GridDesc_K0_MN_K1_T& desc)
|
||||
{
|
||||
|
||||
if constexpr(!DirectLoad)
|
||||
@@ -695,7 +696,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
}
|
||||
|
||||
template <typename DsGridDesc>
|
||||
__device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
__host__ __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
{
|
||||
return generate_tuple(
|
||||
@@ -734,6 +735,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
MBlock{CalculateMBlock(M_)},
|
||||
NBlock{CalculateNBlock(N_)}
|
||||
{
|
||||
Print();
|
||||
}
|
||||
|
||||
__host__ void Print() const
|
||||
@@ -884,7 +886,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
|
||||
}
|
||||
else
|
||||
else if constexpr(DirectLoad && BBlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(Number<NPerBlock * BK1Number>{}, I1, Number<NPerBlock>{}));
|
||||
} else
|
||||
{
|
||||
return Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(DeviceArch{});
|
||||
}
|
||||
@@ -920,7 +927,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
DirectLoad>())>;
|
||||
DirectLoad,
|
||||
LdsScalarLoadToVgpr>())>;
|
||||
|
||||
template <typename DeviceArch>
|
||||
__device__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch)
|
||||
@@ -1213,7 +1221,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
}
|
||||
|
||||
template <typename CGridDesc>
|
||||
__device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
__device__ __host__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
{
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
@@ -1320,7 +1328,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDesc_M_N& ds_grid_desc_m_n,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
const index_t k_batch = 1,
|
||||
const index_t k_idx = 0)
|
||||
{
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
@@ -1346,6 +1356,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
|
||||
//const index_t n_block_data_idx_on_grid =__builtin_amdgcn_readfirstlane(k_id * KPerBlock);
|
||||
|
||||
const index_t num_ak0_per_block =
|
||||
__builtin_amdgcn_readfirstlane(a_grid_desc_ak0_m_ak1.GetLength(I0) / k_batch);
|
||||
const index_t num_bk0_per_block =
|
||||
__builtin_amdgcn_readfirstlane(b_grid_desc_bk0_n_bk1.GetLength(I0) / k_batch);
|
||||
|
||||
// HACK: this forces m/n_block_data_idx_on_grid into SGPR
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
|
||||
@@ -1364,6 +1381,20 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
constexpr auto b_block_desc_bk0_n_bk1 =
|
||||
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch());
|
||||
|
||||
if(threadIdx.x == 0) {
|
||||
if(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) {
|
||||
printf("A RowMajor\n");
|
||||
} else if(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value) {
|
||||
printf("A Colmajor\n");
|
||||
}
|
||||
|
||||
if(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) {
|
||||
printf("B RowMajor\n");
|
||||
} else if(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) {
|
||||
printf("B Colmajor\n");
|
||||
}
|
||||
}
|
||||
|
||||
auto get_a_blockwise_copy = [&]() {
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
@@ -1381,7 +1412,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
make_multi_index(num_ak0_per_block * k_idx, m_block_data_idx_on_grid, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0));
|
||||
}
|
||||
@@ -1411,7 +1442,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
make_multi_index(num_ak0_per_block * k_idx, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
@@ -1434,10 +1465,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
1, // enforced earlier
|
||||
BBlockTransferSrcScalarPerVector>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
make_multi_index(num_bk0_per_block * k_idx, n_block_data_idx_on_grid, 0),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0));
|
||||
}
|
||||
@@ -1467,7 +1498,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
make_multi_index(num_bk0_per_block * k_idx, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
@@ -1486,6 +1517,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
if(threadIdx.x == 0) {
|
||||
printf("a size aligned: %ld, a size: %ld b size: %ld\n", a_block_space_size_aligned.value, a_block_desc_ak0_m_ak1.GetElementSpaceSize().value, b_block_desc_bk0_n_bk1.GetElementSpaceSize().value);
|
||||
}
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<LDSTypeB*>(p_shared) +
|
||||
a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
|
||||
@@ -1501,7 +1536,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
(KPerBlock * k_batch));
|
||||
|
||||
// if(threadIdx.x == 0) {
|
||||
// printf("num_k block main loop: %d\n m_block_data_idx_on_grid: %d\n n_block_data_idx_on_grid: %d\n", num_k_block_main_loop, m_block_data_idx_on_grid, n_block_data_idx_on_grid);
|
||||
// }
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
@@ -1825,6 +1864,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
if(threadIdx.x == 0) {
|
||||
printf("num_k block main loop: %d\n m_block_data_idx_on_grid: %d\n n_block_data_idx_on_grid: %d\n", num_k_block_main_loop, m_block_data_idx_on_grid, n_block_data_idx_on_grid);
|
||||
}
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Short aliases for the element types used when spelling out device-operation
// instances in this header.
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using BF8 = ck::bf8_t;
|
||||
using F8 = ck::f8_t;
|
||||
|
||||
// Empty tuple; passed as the Ds data-type argument of the instances listed in this header.
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
// Shorthand for a compile-time integer sequence (thread-cluster lengths,
// access orders, etc. in the instance parameter lists).
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// Brings the convolution tensor-layout tags into this namespace.
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
// Identity element-wise operation used for the A/B/CDE operations of the instances.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Named shortcuts for the backward-data convolution specializations that can
// be passed as the ConvSpec template argument.
static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
// template <index_t NDimSpatial,
|
||||
// typename ALayout,
|
||||
// typename BLayout,
|
||||
// typename DsLayout,
|
||||
// typename ELayout,
|
||||
// ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
// using device_grouped_conv_bwd_data_xdl_v3_f16_16_16_instances = std::tuple<
|
||||
// // clang-format off
|
||||
// // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
|
||||
// // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector|
|
||||
// // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| |
|
||||
// // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>,
|
||||
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>
|
||||
// // clang-format on
|
||||
// >;
|
||||
|
||||
// Tuple of DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 instance types for
// F16 input/weight/output with F32 accumulation.
//
// Template parameters (forwarded unchanged to every listed instance):
//   NDimSpatial - number of spatial dimensions of the convolution
//   ALayout / BLayout / DsLayout / ELayout - tensor layout tags
//   ConvSpec    - backward-data specialization (e.g. ConvBwdDataDefault)
//
// In this draft exactly ONE instance is active (block size 256, 64x64x64
// M/N/K per block per the column legend below); every other candidate is
// commented out.
// NOTE(review): the active entry carries trailing arguments that the column
// legend does not cover (scheduler, pipeline version, F16, F16, 1, 1, true);
// the final `true` presumably enables the direct-load path — confirm against
// the device-op template parameter list.
// NOTE(review): the legend labels the per-instruction tile columns "Wmma"
// although these are Xdl instances — confirm and relabel.
// NOTE(review): the commented-out candidates differ in whether they end with a
// trailing comma; check separators when re-enabling any of them.
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_v3_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| |
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
//DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, S<2,2,2>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>
|
||||
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, false>,
|
||||
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<2,2,2>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>,
|
||||
//DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, false>
|
||||
// The only active instance in this draft.
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<2,2,2>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>
|
||||
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<2,2,2>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>,
|
||||
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<2,2,2>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>
|
||||
|
||||
|
||||
//DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>,
|
||||
//DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>
|
||||
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>,
|
||||
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>,
|
||||
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 128, 64, 32, 8, 8, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>,
|
||||
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 32, 64, 32, 8, 8, 16, 16, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>,
|
||||
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 128, 64, 64, 8, 8, 16, 16, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -79,7 +79,7 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
@@ -87,7 +87,7 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -95,8 +95,8 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -108,11 +108,12 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
@@ -123,23 +124,23 @@ struct DeviceOperationInstanceFactory<
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_16_16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ComputeTypeA, F32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -148,12 +149,12 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -165,7 +166,7 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f16_instances(op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
@@ -173,7 +174,7 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances(op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -181,8 +182,8 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_bf16_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_bf16_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -194,11 +195,11 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_vec_transpose_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances(op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_16_16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_vec_transpose_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
@@ -206,11 +207,11 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances(op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -218,12 +219,12 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_vec_transpose_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_16_16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_vec_transpose_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -238,8 +239,8 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
@@ -247,8 +248,8 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -256,8 +257,8 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -269,12 +270,12 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
@@ -282,8 +283,8 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, bf8_t> &&
|
||||
is_same_v<ComputeTypeB, f8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
@@ -294,23 +295,23 @@ struct DeviceOperationInstanceFactory<
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ComputeTypeA, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_16_16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -319,12 +320,12 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -336,8 +337,8 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f16_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f16_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
@@ -345,8 +346,8 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f32_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f32_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -354,8 +355,8 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_bf16_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_bf16_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -367,12 +368,12 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_vec_transpose_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_16_16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_vec_transpose_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
@@ -380,12 +381,12 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_vec_transpose_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_16_16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_vec_transpose_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -393,12 +394,12 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_vec_transpose_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_16_16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_vec_transpose_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -416,10 +417,10 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -428,9 +429,9 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -442,14 +443,14 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -457,10 +458,10 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
@@ -468,9 +469,9 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs);
|
||||
// add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -485,10 +486,10 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -497,10 +498,10 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -512,14 +513,14 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -527,10 +528,10 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
@@ -538,10 +539,10 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
|
||||
// op_ptrs);
|
||||
// add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
|
||||
// op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -56,6 +56,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
// Declaration of the registration hook for the f16 "xdl_v3" grouped conv2d
// backward-data instances with the NHWGK / GKYXC / NHWGC layout combination.
// Takes the caller's vector of DeviceGroupedConvBwdDataMultipleD pointers by
// non-const reference. Both Empty_Tuple arguments and the three PassThrough
// arguments are part of the operation type that the vector elements must match.
// NOTE(review): the definition is not visible in this hunk; presumably it lives
// in the new xdl_v3 f16 instance source file added to the build -- confirm.
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
|
||||
@@ -32,6 +32,7 @@ add_instance_library(
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
|
||||
wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Registers the f16 "xdl_v3" grouped conv2d backward-data device instances for
// the NHWGK / GKYXC / NHWGC layout combination.
//
// `instances` is the caller's vector; it is passed to
// add_device_operation_instances together with a default-constructed
// device_grouped_conv_bwd_data_xdl_v3_f16_instances<2, NHWGK, GKYXC,
// Empty_Tuple, NHWGC, ConvBwdDataDefault> object.
//
// Only the ConvBwdDataDefault specialization is registered here; the
// Filter1x1Stride1Pad0 registration further down is commented out.
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default: register the ConvBwdDataDefault specialization of the v3 f16 instance list.
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_v3_f16_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0 (currently disabled).
// NOTE(review): the disabled call below names the non-v3 template
// device_grouped_conv_bwd_data_xdl_f16_instances -- confirm which template is
// intended before re-enabling it.
|
||||
// add_device_operation_instances(
|
||||
// instances,
|
||||
// device_grouped_conv_bwd_data_xdl_f16_instances<2,
|
||||
// NHWGK,
|
||||
// GKYXC,
|
||||
// Empty_Tuple,
|
||||
// NHWGC,
|
||||
// ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -121,6 +121,38 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
|
||||
break;
|
||||
case 3:
|
||||
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{2});
|
||||
break;
|
||||
case 4:
|
||||
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{2});
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
break;
|
||||
case 5:
|
||||
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
break;
|
||||
case 6:
|
||||
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{0.0, 1.0});
|
||||
break;
|
||||
case 7:
|
||||
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{0.0, 1.0});
|
||||
break;
|
||||
case 8:
|
||||
out.GenerateTensorValue(GeneratorTensor_Sequential<OutDataType, 2>{});
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
break;
|
||||
case 9:
|
||||
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
wei.GenerateTensorValue(GeneratorTensor_Sequential<WeiDataType, 1>{});
|
||||
break;
|
||||
case 10:
|
||||
out.GenerateTensorValue(GeneratorTensor_Sequential<OutDataType, 2>{});
|
||||
wei.GenerateTensorValue(GeneratorTensor_Sequential<WeiDataType, 1>{});
|
||||
break;
|
||||
default:
|
||||
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
@@ -210,6 +242,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
// workspace_sz will be equal to 0 for other layout than NGCHW
|
||||
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
DeviceMem workspace_dev(workspace_sz);
|
||||
// printf("run impl\n");
|
||||
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
@@ -224,8 +257,10 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
// printf("prerun\n");
|
||||
float avg_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
// printf("post run\n");
|
||||
|
||||
std::size_t flop = conv_param.GetFlops();
|
||||
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
|
||||
@@ -14,118 +14,118 @@ message(STATUS "CK_PROFILER_OP_FILTER: ${CK_PROFILER_OP_FILTER}")
|
||||
message(STATUS "CK_PROFILER_INSTANCE_FILTER: ${CK_PROFILER_INSTANCE_FILTER}")
|
||||
|
||||
set(PROFILER_OPS
|
||||
profile_gemm.cpp
|
||||
profile_reduce.cpp
|
||||
profile_groupnorm_bwd_data.cpp
|
||||
profile_groupnorm_fwd.cpp
|
||||
profile_layernorm_bwd_data.cpp
|
||||
profile_layernorm_bwd_gamma_beta.cpp
|
||||
profile_groupnorm_bwd_gamma_beta.cpp
|
||||
profile_layernorm_fwd.cpp
|
||||
profile_max_pool2d_fwd.cpp
|
||||
profile_pool3d_fwd.cpp
|
||||
profile_avg_pool3d_bwd.cpp
|
||||
profile_max_pool3d_bwd.cpp
|
||||
profile_avg_pool2d_bwd.cpp
|
||||
profile_max_pool2d_bwd.cpp
|
||||
profile_softmax.cpp
|
||||
profile_batchnorm_fwd.cpp
|
||||
profile_batchnorm_bwd.cpp
|
||||
profile_batchnorm_infer.cpp
|
||||
profile_conv_tensor_rearrange.cpp
|
||||
profile_transpose.cpp
|
||||
profile_permute_scale.cpp
|
||||
profile_gemm_quantization.cpp
|
||||
# profile_gemm.cpp
|
||||
# profile_reduce.cpp
|
||||
# profile_groupnorm_bwd_data.cpp
|
||||
# profile_groupnorm_fwd.cpp
|
||||
# profile_layernorm_bwd_data.cpp
|
||||
# profile_layernorm_bwd_gamma_beta.cpp
|
||||
# profile_groupnorm_bwd_gamma_beta.cpp
|
||||
# profile_layernorm_fwd.cpp
|
||||
# profile_max_pool2d_fwd.cpp
|
||||
# profile_pool3d_fwd.cpp
|
||||
# profile_avg_pool3d_bwd.cpp
|
||||
# profile_max_pool3d_bwd.cpp
|
||||
# profile_avg_pool2d_bwd.cpp
|
||||
# profile_max_pool2d_bwd.cpp
|
||||
# profile_softmax.cpp
|
||||
# profile_batchnorm_fwd.cpp
|
||||
# profile_batchnorm_bwd.cpp
|
||||
# profile_batchnorm_infer.cpp
|
||||
# profile_conv_tensor_rearrange.cpp
|
||||
# profile_transpose.cpp
|
||||
# profile_permute_scale.cpp
|
||||
# profile_gemm_quantization.cpp
|
||||
)
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp)
|
||||
list(APPEND PROFILER_OPS profile_contraction_scale.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_contraction_scale.cpp)
|
||||
endif()
|
||||
if(CK_EXPERIMENTAL_BUILDER)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_tile.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_tile.cpp)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_OPS profile_gemm_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_reduce.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12")
|
||||
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_preshuffle.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_universal_preshuffle.cpp)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
|
||||
list(APPEND PROFILER_OPS profile_gemm_mx.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_mx.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_bwd_data.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_fwd.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_conv_fwd.cpp)
|
||||
endif()
|
||||
|
||||
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR
|
||||
(SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]"))
|
||||
list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx(9[45]|1[12])")
|
||||
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp)
|
||||
endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_bnorm_clamp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_bnorm_clamp.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_multi_abd.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_multi_abd.cpp)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_add.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
endif()
|
||||
|
||||
if(CK_ENABLE_INT8)
|
||||
list(APPEND PROFILER_OPS profile_gemm_quantization.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_quantization.cpp)
|
||||
endif()
|
||||
|
||||
set(PROFILER_SOURCES profiler.cpp)
|
||||
@@ -152,131 +152,131 @@ endif()
|
||||
|
||||
|
||||
set(DEVICE_INSTANCES "")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_normalization_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_normalization_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_normalization_bwd_gamma_beta_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_softmax_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batchnorm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_pool2d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_pool3d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_avg_pool2d_bwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_avg_pool3d_bwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_max_pool_bwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_image_to_column_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_column_to_image_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_transpose_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_permute_scale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_normalization_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_normalization_bwd_data_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_normalization_bwd_gamma_beta_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_softmax_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_reduce_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batchnorm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_pool2d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_pool3d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_avg_pool2d_bwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_avg_pool3d_bwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_max_pool_bwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_image_to_column_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_column_to_image_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_transpose_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_permute_scale_instance)
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
|
||||
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_contraction_scale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_contraction_scale_instance)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
|
||||
endif()
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx1[12]")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
|
||||
endif()
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv1d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv3d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
|
||||
endif()
|
||||
|
||||
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR
|
||||
(SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]" ))
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx(9[45]|1[12])")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
|
||||
endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_bnorm_clamp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_bnorm_clamp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
|
||||
endif()
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(CK_EXPERIMENTAL_BUILDER)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv_fwd_tile_instances)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv_fwd_tile_instances)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
endif()
|
||||
|
||||
if(CK_ENABLE_INT8)
|
||||
list(APPEND DEVICE_INSTANCES device_quantization_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_quantization_instance)
|
||||
endif()
|
||||
|
||||
set(PROFILER_LIBS utility getopt::getopt)
|
||||
|
||||
Reference in New Issue
Block a user