added draft of conv bwd data direct loads

This commit is contained in:
Jakub Piasecki
2026-01-29 20:11:26 +00:00
parent f16d9100e4
commit 6f27bf1d7f
10 changed files with 2635 additions and 327 deletions

View File

@@ -187,6 +187,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
if(threadIdx.x == 0) printf("intra\n");
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
@@ -212,6 +213,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
index_t i = 0;
do
{
if(threadIdx.x == 0) printf("hotloop: %d\n", i);
// -------------------------------------------------------------------------------------------
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
@@ -280,6 +282,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
// tail
if constexpr(TailNum == TailNumber::Full)
{
if(threadIdx.x == 0) printf("tail\n");
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
@@ -919,6 +922,7 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
if(threadIdx.x == 0) printf("v1 intra directload, num_loop: %d\n", num_loop);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
@@ -942,6 +946,7 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
index_t i = 0;
do
{
if(threadIdx.x == 0) printf("has Main loop %d\n", i);
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
@@ -981,6 +986,14 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
if(threadIdx.x == 0) {
printf("a: %f b: %f\n",
static_cast<float>(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]),
static_cast<float>(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}]));
}
});
using mfma_input_type =
@@ -1007,6 +1020,7 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
// tail
if constexpr(TailNum == TailNumber::Full)
{
if(threadIdx.x == 0) printf("Tail full\n");
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
@@ -1032,6 +1046,10 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
if(threadIdx.x == 0) {
printf("Repeat: (M N K): (%d, %d, %d)\n", m0.value, n0.value, k0.value);
}
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
@@ -1039,6 +1057,17 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
if(threadIdx.x == 0) {
printf("a: %f b: %f a_off: %d b_off: %d\n",
static_cast<float>(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]),
static_cast<float>(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}]),
a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik)),
b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))
);
}
});
using mfma_input_type =

View File

@@ -151,7 +151,8 @@ template <typename ALayout,
typename LDSTypeA = ADataType,
typename LDSTypeB = BDataType,
bool DoElementwiseBeforeCShuffle = false,
bool DirectLoad = false>
bool DirectLoad = false,
bool LdsScalarLoadToVgpr = false>
struct GridwiseGemmMultiD_xdl_cshuffle_v3
: public GridwiseGemm_xdl_cshuffle_base<
ALayout,
@@ -197,7 +198,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
CDEShuffleBlockTransferScalarPerVectors,
ComputeTypeA,
ComputeTypeB,
BlkGemmPipelineVer == BlockGemmPipelineVersion::v4,
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1,
DirectLoad>
{
static_assert((is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
@@ -248,7 +249,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
CDEShuffleBlockTransferScalarPerVectors,
ComputeTypeA,
ComputeTypeB,
BlkGemmPipelineVer == BlockGemmPipelineVersion::v4,
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1,
DirectLoad>;
using Base::AK0Number;
@@ -355,7 +356,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
}
template <typename GridDesc_K0_MN_K1_T, index_t K0Number, index_t K1Value>
__host__ __device__ static auto TransformGrid(GridDesc_K0_MN_K1_T& desc)
__host__ __device__ static auto TransformGrid(const GridDesc_K0_MN_K1_T& desc)
{
if constexpr(!DirectLoad)
@@ -695,7 +696,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
}
template <typename DsGridDesc>
__device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
__host__ __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
return generate_tuple(
@@ -734,6 +735,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
MBlock{CalculateMBlock(M_)},
NBlock{CalculateNBlock(N_)}
{
Print();
}
__host__ void Print() const
@@ -884,7 +886,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
}
else
else if constexpr(DirectLoad && BBlockTransferSrcVectorDim == 1)
{
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(Number<NPerBlock * BK1Number>{}, I1, Number<NPerBlock>{}));
} else
{
return Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(DeviceArch{});
}
@@ -920,7 +927,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
MXdlPerWave,
NXdlPerWave,
KPack,
DirectLoad>())>;
DirectLoad,
LdsScalarLoadToVgpr>())>;
template <typename DeviceArch>
__device__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch)
@@ -1213,7 +1221,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
}
template <typename CGridDesc>
__device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
__device__ __host__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
@@ -1320,7 +1328,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const CGridDesc_M_N& c_grid_desc_m_n)
const CGridDesc_M_N& c_grid_desc_m_n,
const index_t k_batch = 1,
const index_t k_idx = 0)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -1346,6 +1356,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
//const index_t n_block_data_idx_on_grid =__builtin_amdgcn_readfirstlane(k_id * KPerBlock);
const index_t num_ak0_per_block =
__builtin_amdgcn_readfirstlane(a_grid_desc_ak0_m_ak1.GetLength(I0) / k_batch);
const index_t num_bk0_per_block =
__builtin_amdgcn_readfirstlane(b_grid_desc_bk0_n_bk1.GetLength(I0) / k_batch);
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
@@ -1364,6 +1381,20 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch());
if(threadIdx.x == 0) {
if(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) {
printf("A RowMajor\n");
} else if(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value) {
printf("A Colmajor\n");
}
if(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) {
printf("B RowMajor\n");
} else if(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) {
printf("B Colmajor\n");
}
}
auto get_a_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
@@ -1381,7 +1412,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
2,
ABlockTransferSrcScalarPerVector>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
make_multi_index(num_ak0_per_block * k_idx, m_block_data_idx_on_grid, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
}
@@ -1411,7 +1442,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
make_multi_index(num_ak0_per_block * k_idx, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -1434,10 +1465,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
2,
1, // enforced earlier
BBlockTransferSrcScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
make_multi_index(num_bk0_per_block * k_idx, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
}
@@ -1467,7 +1498,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
make_multi_index(num_bk0_per_block * k_idx, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
@@ -1486,6 +1517,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
if(threadIdx.x == 0) {
printf("a size aligned: %ld, a size: %ld b size: %ld\n", a_block_space_size_aligned.value, a_block_desc_ak0_m_ak1.GetElementSpaceSize().value, b_block_desc_bk0_n_bk1.GetElementSpaceSize().value);
}
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeB*>(p_shared) +
a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
@@ -1501,7 +1536,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
(KPerBlock * k_batch));
// if(threadIdx.x == 0) {
// printf("num_k block main loop: %d\n m_block_data_idx_on_grid: %d\n n_block_data_idx_on_grid: %d\n", num_k_block_main_loop, m_block_data_idx_on_grid, n_block_data_idx_on_grid);
// }
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
@@ -1825,6 +1864,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
if(threadIdx.x == 0) {
printf("num_k block main loop: %d\n m_block_data_idx_on_grid: %d\n n_block_data_idx_on_grid: %d\n", num_k_block_main_loop, m_block_data_idx_on_grid, n_block_data_idx_on_grid);
}
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,

View File

@@ -0,0 +1,90 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using BF8 = ck::bf8_t;
using F8 = ck::f8_t;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default;
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
// template <index_t NDimSpatial,
// typename ALayout,
// typename BLayout,
// typename DsLayout,
// typename ELayout,
// ConvolutionBackwardDataSpecialization ConvSpec>
// using device_grouped_conv_bwd_data_xdl_v3_f16_16_16_instances = std::tuple<
// // clang-format off
// // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
// // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector|
// // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| |
// // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>
// // clang-format on
// >;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_v3_f16_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| |
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
//DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, S<2,2,2>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, false>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<2,2,2>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>,
//DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, false>
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<2,2,2>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<2,2,2>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<2,2,2>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>
//DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>,
//DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 128, 64, 32, 8, 8, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 32, 64, 32, 8, 8, 16, 16, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 128, 64, 64, 8, 8, 16, 16, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -79,7 +79,7 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
@@ -87,7 +87,7 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -95,8 +95,8 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
// op_ptrs);
}
#endif
}
@@ -108,11 +108,12 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
// op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances(
// op_ptrs);
}
#endif
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
@@ -123,23 +124,23 @@ struct DeviceOperationInstanceFactory<
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_16_16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances(
op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances(
// op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_16_16_instances(
// op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
// op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances(
// op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
// op_ptrs);
}
#endif
}
@@ -148,12 +149,12 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances(
op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
// op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
// op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances(
// op_ptrs);
}
#endif
}
@@ -165,7 +166,7 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f16_instances(op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
@@ -173,7 +174,7 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances(op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -181,8 +182,8 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_bf16_instances(
op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_bf16_instances(
// op_ptrs);
}
#endif
}
@@ -194,11 +195,11 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_16_16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_vec_transpose_instances(
op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances(op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_16_16_instances(
// op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_vec_transpose_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
@@ -206,11 +207,11 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances(
op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances(op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances(
// op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -218,12 +219,12 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_16_16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_vec_transpose_instances(
op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances(
// op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_16_16_instances(
// op_ptrs);
// add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_vec_transpose_instances(
// op_ptrs);
}
#endif
}
@@ -238,8 +239,8 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
@@ -247,8 +248,8 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -256,8 +257,8 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
// op_ptrs);
}
#endif
}
@@ -269,12 +270,12 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances(
// op_ptrs);
}
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
@@ -282,8 +283,8 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, bf8_t> &&
is_same_v<ComputeTypeB, f8_t>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
// op_ptrs);
}
#endif
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
@@ -294,23 +295,23 @@ struct DeviceOperationInstanceFactory<
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_16_16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_16_16_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances(
// op_ptrs);
}
#endif
}
@@ -319,12 +320,12 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances(
// op_ptrs);
}
#endif
}
@@ -336,8 +337,8 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f16_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f16_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
@@ -345,8 +346,8 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f32_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f32_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -354,8 +355,8 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_bf16_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_bf16_instances(
// op_ptrs);
}
#endif
}
@@ -367,12 +368,12 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_16_16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_vec_transpose_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_16_16_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_vec_transpose_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
@@ -380,12 +381,12 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_16_16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_vec_transpose_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_16_16_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_vec_transpose_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -393,12 +394,12 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_16_16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_vec_transpose_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_16_16_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_vec_transpose_instances(
// op_ptrs);
}
#endif
}
@@ -416,10 +417,10 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
op_ptrs);
// add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
// op_ptrs);
// add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
// op_ptrs);
}
#endif
@@ -428,9 +429,9 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
op_ptrs);
// add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs);
// add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
// op_ptrs);
}
#endif
}
@@ -442,14 +443,14 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
op_ptrs);
// add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_instances(
// op_ptrs);
// add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
// op_ptrs);
// add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
// op_ptrs);
// add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -457,10 +458,10 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
op_ptrs);
// add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_instances(
// op_ptrs);
// add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
@@ -468,9 +469,9 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
op_ptrs);
// add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs);
// add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
// op_ptrs);
}
#endif
}
@@ -485,10 +486,10 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
// op_ptrs);
}
#endif
@@ -497,10 +498,10 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
// op_ptrs);
}
#endif
}
@@ -512,14 +513,14 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -527,10 +528,10 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
@@ -538,10 +539,10 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
op_ptrs);
// add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
// op_ptrs);
// add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
// op_ptrs);
}
#endif
}

View File

@@ -56,6 +56,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,

View File

@@ -32,6 +32,7 @@ add_instance_library(
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp
wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp

View File

@@ -0,0 +1,49 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_v3_f16_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
// add_device_operation_instances(
// instances,
// device_grouped_conv_bwd_data_xdl_f16_instances<2,
// NHWGK,
// GKYXC,
// Empty_Tuple,
// NHWGC,
// ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -121,6 +121,38 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
break;
case 3:
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{2});
break;
case 4:
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{2});
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
break;
case 5:
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
break;
case 6:
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{0.0, 1.0});
break;
case 7:
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{0.0, 1.0});
break;
case 8:
out.GenerateTensorValue(GeneratorTensor_Sequential<OutDataType, 2>{});
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
break;
case 9:
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
wei.GenerateTensorValue(GeneratorTensor_Sequential<WeiDataType, 1>{});
break;
case 10:
out.GenerateTensorValue(GeneratorTensor_Sequential<OutDataType, 2>{});
wei.GenerateTensorValue(GeneratorTensor_Sequential<WeiDataType, 1>{});
break;
default:
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
@@ -210,6 +242,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
// workspace_sz will be equal to 0 for other layout than NGCHW
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
// printf("run impl\n");
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
@@ -224,8 +257,10 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
auto invoker_ptr = op_ptr->MakeInvokerPointer();
// printf("prerun\n");
float avg_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
// printf("post run\n");
std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();

View File

@@ -14,118 +14,118 @@ message(STATUS "CK_PROFILER_OP_FILTER: ${CK_PROFILER_OP_FILTER}")
message(STATUS "CK_PROFILER_INSTANCE_FILTER: ${CK_PROFILER_INSTANCE_FILTER}")
set(PROFILER_OPS
profile_gemm.cpp
profile_reduce.cpp
profile_groupnorm_bwd_data.cpp
profile_groupnorm_fwd.cpp
profile_layernorm_bwd_data.cpp
profile_layernorm_bwd_gamma_beta.cpp
profile_groupnorm_bwd_gamma_beta.cpp
profile_layernorm_fwd.cpp
profile_max_pool2d_fwd.cpp
profile_pool3d_fwd.cpp
profile_avg_pool3d_bwd.cpp
profile_max_pool3d_bwd.cpp
profile_avg_pool2d_bwd.cpp
profile_max_pool2d_bwd.cpp
profile_softmax.cpp
profile_batchnorm_fwd.cpp
profile_batchnorm_bwd.cpp
profile_batchnorm_infer.cpp
profile_conv_tensor_rearrange.cpp
profile_transpose.cpp
profile_permute_scale.cpp
profile_gemm_quantization.cpp
# profile_gemm.cpp
# profile_reduce.cpp
# profile_groupnorm_bwd_data.cpp
# profile_groupnorm_fwd.cpp
# profile_layernorm_bwd_data.cpp
# profile_layernorm_bwd_gamma_beta.cpp
# profile_groupnorm_bwd_gamma_beta.cpp
# profile_layernorm_fwd.cpp
# profile_max_pool2d_fwd.cpp
# profile_pool3d_fwd.cpp
# profile_avg_pool3d_bwd.cpp
# profile_max_pool3d_bwd.cpp
# profile_avg_pool2d_bwd.cpp
# profile_max_pool2d_bwd.cpp
# profile_softmax.cpp
# profile_batchnorm_fwd.cpp
# profile_batchnorm_bwd.cpp
# profile_batchnorm_infer.cpp
# profile_conv_tensor_rearrange.cpp
# profile_transpose.cpp
# profile_permute_scale.cpp
# profile_gemm_quantization.cpp
)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp)
list(APPEND PROFILER_OPS profile_contraction_scale.cpp)
# list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp)
# list(APPEND PROFILER_OPS profile_contraction_scale.cpp)
endif()
if(CK_EXPERIMENTAL_BUILDER)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_tile.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_tile.cpp)
endif()
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_OPS profile_gemm_reduce.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp)
# list(APPEND PROFILER_OPS profile_gemm_reduce.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add.cpp)
# list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
# list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp)
# list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp)
# list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp)
# list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp)
# list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12")
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp)
list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_preshuffle.cpp)
# list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
# list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp)
# list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp)
# list(APPEND PROFILER_OPS profile_gemm_universal_preshuffle.cpp)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
list(APPEND PROFILER_OPS profile_gemm_mx.cpp)
# list(APPEND PROFILER_OPS profile_gemm_mx.cpp)
endif()
list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add.cpp)
# list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
# list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
# list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
# list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
# list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
# list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp)
list(APPEND PROFILER_OPS profile_conv_bwd_data.cpp)
list(APPEND PROFILER_OPS profile_conv_fwd.cpp)
# list(APPEND PROFILER_OPS profile_conv_fwd.cpp)
endif()
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR
(SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]"))
list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
# list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx(9[45]|1[12])")
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp)
# list(APPEND PROFILER_OPS profile_gemm_multiply_mkultiply.cpp)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_bnorm_clamp.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
# list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
# list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
# list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_bnorm_clamp.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp)
list(APPEND PROFILER_OPS profile_gemm_multi_abd.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp)
# list(APPEND PROFILER_OPS profile_gemm_multi_abd.cpp)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp)
list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
# list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp)
# list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_add.cpp)
endif()
list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
endif()
if(DL_KERNELS)
list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
endif()
if(CK_ENABLE_INT8)
list(APPEND PROFILER_OPS profile_gemm_quantization.cpp)
# list(APPEND PROFILER_OPS profile_gemm_quantization.cpp)
endif()
set(PROFILER_SOURCES profiler.cpp)
@@ -152,131 +152,131 @@ endif()
set(DEVICE_INSTANCES "")
list(APPEND DEVICE_INSTANCES device_gemm_instance)
list(APPEND DEVICE_INSTANCES device_normalization_fwd_instance)
list(APPEND DEVICE_INSTANCES device_normalization_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_normalization_bwd_gamma_beta_instance)
list(APPEND DEVICE_INSTANCES device_softmax_instance)
list(APPEND DEVICE_INSTANCES device_reduce_instance)
list(APPEND DEVICE_INSTANCES device_batchnorm_instance)
list(APPEND DEVICE_INSTANCES device_pool2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_pool3d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_avg_pool2d_bwd_instance)
list(APPEND DEVICE_INSTANCES device_avg_pool3d_bwd_instance)
list(APPEND DEVICE_INSTANCES device_max_pool_bwd_instance)
list(APPEND DEVICE_INSTANCES device_image_to_column_instance)
list(APPEND DEVICE_INSTANCES device_column_to_image_instance)
list(APPEND DEVICE_INSTANCES device_transpose_instance)
list(APPEND DEVICE_INSTANCES device_permute_scale_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_instance)
# list(APPEND DEVICE_INSTANCES device_normalization_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_normalization_bwd_data_instance)
# list(APPEND DEVICE_INSTANCES device_normalization_bwd_gamma_beta_instance)
# list(APPEND DEVICE_INSTANCES device_softmax_instance)
# list(APPEND DEVICE_INSTANCES device_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_batchnorm_instance)
# list(APPEND DEVICE_INSTANCES device_pool2d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_pool3d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_avg_pool2d_bwd_instance)
# list(APPEND DEVICE_INSTANCES device_avg_pool3d_bwd_instance)
# list(APPEND DEVICE_INSTANCES device_max_pool_bwd_instance)
# list(APPEND DEVICE_INSTANCES device_image_to_column_instance)
# list(APPEND DEVICE_INSTANCES device_column_to_image_instance)
# list(APPEND DEVICE_INSTANCES device_transpose_instance)
# list(APPEND DEVICE_INSTANCES device_permute_scale_instance)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance)
list(APPEND DEVICE_INSTANCES device_contraction_scale_instance)
# list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance)
# list(APPEND DEVICE_INSTANCES device_contraction_scale_instance)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
endif()
list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12")
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx1[12]")
list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
endif()
list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance)
list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance)
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance)
list(APPEND DEVICE_INSTANCES device_conv1d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_conv3d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
endif()
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR
(SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]" ))
list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx(9[45]|1[12])")
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_bnorm_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_bnorm_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
endif()
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(CK_EXPERIMENTAL_BUILDER)
list(APPEND DEVICE_INSTANCES device_grouped_conv_fwd_tile_instances)
# list(APPEND DEVICE_INSTANCES device_grouped_conv_fwd_tile_instances)
endif()
endif()
if(DL_KERNELS)
list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
endif()
if(CK_ENABLE_INT8)
list(APPEND DEVICE_INSTANCES device_quantization_instance)
# list(APPEND DEVICE_INSTANCES device_quantization_instance)
endif()
set(PROFILER_LIBS utility getopt::getopt)