mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add instances for conv_scale with fp8@bf8->fp8 (#1220)
* Update device op api to support BComputeType
* Add example
* Add instances
* Add profiler mode
* Add client example
* Update copyright year
* Add BComputeType check
* Fix compute types
[ROCm/composable_kernel commit: a61e73bc56]
This commit is contained in:
@@ -17,6 +17,11 @@ if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
|
||||
target_link_libraries(client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations)
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
|
||||
add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations)
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES)
|
||||
add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iomanip>
|
||||
@@ -95,7 +95,8 @@ template <ck::index_t NumDimSpatial,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
ck::index_t NumNonSpatialDim = 3,
|
||||
typename ComputeType = InDataType>
|
||||
typename AComputeType = InDataType,
|
||||
typename BComputeType = AComputeType>
|
||||
bool run_grouped_conv_fwd(std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_lengths,
|
||||
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_lengths,
|
||||
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_lengths)
|
||||
@@ -186,7 +187,8 @@ bool run_grouped_conv_fwd(std::array<ck::index_t, NumDimSpatial + NumNonSpatialD
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ComputeType>;
|
||||
AComputeType,
|
||||
BComputeType>;
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
50
client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp
Normal file
50
client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp
Normal file
@@ -0,0 +1,50 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
using InDataType = ck::f8_t;
|
||||
using WeiDataType = ck::bf8_t;
|
||||
using OutDataType = ck::f8_t;
|
||||
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
using AComputeType = ck::f8_t;
|
||||
using BComputeType = ck::bf8_t;
|
||||
|
||||
static constexpr ck::index_t NumDimSpatial = 3;
|
||||
static constexpr ck::index_t G = 1;
|
||||
static constexpr ck::index_t N = 64;
|
||||
static constexpr ck::index_t K = 128;
|
||||
static constexpr ck::index_t C = 64;
|
||||
static constexpr ck::index_t Z = 3;
|
||||
static constexpr ck::index_t Y = 3;
|
||||
static constexpr ck::index_t X = 3;
|
||||
static constexpr ck::index_t Di = 28;
|
||||
static constexpr ck::index_t Hi = 28;
|
||||
static constexpr ck::index_t Wi = 3;
|
||||
static constexpr ck::index_t Do = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 3;
|
||||
|
||||
int main()
|
||||
{
|
||||
return run_grouped_conv_fwd<NumDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
3,
|
||||
AComputeType,
|
||||
BComputeType>(
|
||||
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
@@ -5,6 +5,7 @@ add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp8_bf8 convnd_fwd_xdl_fp8_bf8.cpp)
|
||||
add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp)
|
||||
add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp)
|
||||
add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp)
|
||||
|
||||
83
example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp
Normal file
83
example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp
Normal file
@@ -0,0 +1,83 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "convnd_fwd_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
|
||||
using InDataType = ck::f8_t;
|
||||
using WeiDataType = ck::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::f8_t;
|
||||
using OutDataType = ck::f8_t;
|
||||
using AComputeType = ck::f8_t;
|
||||
using BComputeType = ck::bf8_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvSpec =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
// Device-op instance used by this example: grouped N-D forward convolution
// with the f8 input / bf8 weight / f8 output types aliased above, float
// accumulation, no extra D tensors (empty tuples) and PassThrough element-wise
// operations. The argument list is positional; trailing comments name each slot.
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
using DeviceGroupedConvNDFwdInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>, // DsLayout (no D tensors)
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ck::Tuple<>, // DsDataType (no D tensors)
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
1, // NumGemmKPrefetchStage
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
AComputeType, // compute type for the A (input) tensor
|
||||
BComputeType>; // compute type for the B (weight) tensor
|
||||
|
||||
#include "run_convnd_fwd_example.inc"
|
||||
|
||||
// Entry point: hands the command line to the shared convnd-forward example
// driver and maps its result onto the process exit status (0 on success).
int main(int argc, char* argv[])
{
    if(run_convnd_fwd_example(argc, argv))
    {
        return 0;
    }
    return 1;
}
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -40,7 +40,8 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
* \tparam AElementwiseOperation A elementwise operation.
|
||||
* \tparam BElementwiseOperation B elementwise operation.
|
||||
* \tparam CDEElementwiseOperation CDE elementwise operation.
|
||||
* \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
|
||||
* \tparam AComputeType Compute data type for A tensor (default: ADataType, first if tuple passed).
|
||||
* \tparam BComputeType Compute data type for B tensor (default: AComputeType).
|
||||
*/
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
@@ -54,12 +55,13 @@ template <index_t NDimSpatial,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename ComputeType =
|
||||
typename AComputeType =
|
||||
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
|
||||
Number<0>,
|
||||
ADataType>())> // ComputeType is InputType by default (first
|
||||
ADataType>()), // AComputeType is InputType by default (first
|
||||
// in tuple for MultiAB), unpack if tuple was
|
||||
// passed
|
||||
typename BComputeType = AComputeType>
|
||||
struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
|
||||
{
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -254,13 +254,14 @@ template <index_t NDimSpatial,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename ComputeDataType =
|
||||
typename AComputeDataType =
|
||||
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
|
||||
Number<0>,
|
||||
ADataType>()), // AComputeDataType is InputType by default (first
|
||||
// in tuple for MultiAB), unpack if tuple was
|
||||
// passed
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
typename BComputeDataType = AComputeDataType,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
|
||||
ALayout,
|
||||
@@ -274,7 +275,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
ComputeDataType>
|
||||
AComputeDataType,
|
||||
BComputeDataType>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
|
||||
|
||||
@@ -386,7 +388,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
|
||||
|
||||
#define GridwiseGemmTemplateParameters \
|
||||
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
|
||||
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
|
||||
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
|
||||
@@ -399,7 +401,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
BComputeDataType
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm =
|
||||
std::conditional_t<isMultiA || isMultiB,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -75,13 +75,14 @@ template <index_t NDimSpatial,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename ComputeDataType =
|
||||
typename AComputeDataType =
|
||||
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
|
||||
Number<0>,
|
||||
ADataType>()), // AComputeDataType is InputType by default (first
|
||||
// in tuple for MultiAB), unpack if tuple was
|
||||
// passed
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
typename BComputeDataType = AComputeDataType,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
NDimSpatial,
|
||||
ALayout,
|
||||
@@ -128,7 +129,8 @@ using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipl
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
ComputeDataType,
|
||||
AComputeDataType,
|
||||
BComputeDataType,
|
||||
LoopSched>;
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -30,7 +30,7 @@ namespace ck {
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename ComputeDataType_,
|
||||
typename AComputeDataType_,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
@@ -71,7 +71,8 @@ template <typename AsDataType,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
typename BComputeDataType_ = AComputeDataType_>
|
||||
struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
{
|
||||
static constexpr index_t NumATensor = AsDataType::Size();
|
||||
@@ -101,10 +102,13 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
|
||||
#if CK_WORKAROUND_DENORM_FIX
|
||||
using ComputeDataType =
|
||||
conditional_t<is_same_v<ComputeDataType_, ck::half_t>, ck::bhalf_t, ComputeDataType_>;
|
||||
using AComputeDataType =
|
||||
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
|
||||
using BComputeDataType =
|
||||
conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
|
||||
#else
|
||||
using ComputeDataType = ComputeDataType_;
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
using BComputeDataType = BComputeDataType_;
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
@@ -195,8 +199,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
constexpr auto c_block_size =
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
|
||||
|
||||
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
|
||||
sizeof(ComputeDataType),
|
||||
return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) +
|
||||
b_block_space_size_aligned * sizeof(BComputeDataType),
|
||||
c_block_size * sizeof(CShuffleDataType));
|
||||
}
|
||||
|
||||
@@ -597,7 +601,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
|
||||
ThisThreadBlock,
|
||||
AsDataType,
|
||||
Tuple<ComputeDataType>,
|
||||
Tuple<AComputeDataType>,
|
||||
decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(tie(a_block_desc_ak0_m_ak1)),
|
||||
AElementwiseOperation,
|
||||
@@ -628,7 +632,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
|
||||
ThisThreadBlock,
|
||||
BsDataType,
|
||||
Tuple<ComputeDataType>,
|
||||
Tuple<BComputeDataType>,
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(tie(b_block_desc_bk0_n_bk1)),
|
||||
BElementwiseOperation,
|
||||
@@ -656,14 +660,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1, BK1),
|
||||
MfmaSelector<ComputeDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1, BK1),
|
||||
MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
|
||||
.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
ComputeDataType, // ComputeDataType for A
|
||||
ComputeDataType, // ComputeDataType for B
|
||||
AComputeDataType,
|
||||
BComputeDataType,
|
||||
AccDataType,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
@@ -681,10 +686,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
static_cast<AComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ComputeDataType*>(p_shared) + a_block_space_size_aligned,
|
||||
static_cast<BComputeDataType*>(p_shared) + a_block_space_size_aligned,
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -73,7 +73,7 @@ template <typename ADataType,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
typename BComputeDataType = AComputeDataType_>
|
||||
typename BComputeDataType_ = AComputeDataType_>
|
||||
struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
@@ -103,8 +103,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
#if CK_WORKAROUND_DENORM_FIX
|
||||
using AComputeDataType =
|
||||
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
|
||||
using BComputeDataType =
|
||||
conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
|
||||
#else
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
using BComputeDataType = BComputeDataType_;
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
|
||||
@@ -290,6 +290,42 @@ using device_grouped_conv_fwd_xdl_bf8_instances = std::tuple<
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionForwardSpecialization ConvSpec>
|
||||
using device_grouped_conv_fwd_xdl_f8_bf8_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|AComputeType|BComputeType|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8))
|
||||
// generic instance
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>,
|
||||
// instances for small conv.K and conv.C
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
|
||||
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -34,7 +34,8 @@ template <ck::index_t NumDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename ComputeType>
|
||||
typename AComputeType,
|
||||
typename BComputeType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
@@ -48,7 +49,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ComputeType>>
|
||||
AComputeType,
|
||||
BComputeType>>
|
||||
{
|
||||
using DeviceOp =
|
||||
DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
|
||||
@@ -63,7 +65,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ComputeType>;
|
||||
AComputeType,
|
||||
BComputeType>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
@@ -75,14 +78,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<ComputeType, float>)
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
@@ -94,14 +99,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<ComputeType, float>)
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
|
||||
}
|
||||
@@ -115,14 +122,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<ComputeType, float>)
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
|
||||
}
|
||||
@@ -130,14 +139,17 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> && is_same_v<ComputeType, bhalf_t>)
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<AComputeType, ck::bhalf_t> &&
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeType, int8_t>)
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
|
||||
is_same_v<BComputeType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(op_ptrs);
|
||||
}
|
||||
@@ -149,14 +161,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<ComputeType, float>)
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
@@ -164,7 +178,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> && is_same_v<ComputeType, bhalf_t>)
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<AComputeType, ck::bhalf_t> &&
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
@@ -176,14 +192,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<ComputeType, float>)
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
|
||||
}
|
||||
@@ -191,7 +209,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> && is_same_v<ComputeType, bhalf_t>)
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<AComputeType, ck::bhalf_t> &&
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
|
||||
}
|
||||
@@ -203,14 +223,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<ComputeType, float>)
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
@@ -218,14 +240,17 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> && is_same_v<ComputeType, bhalf_t>)
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<AComputeType, ck::bhalf_t> &&
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeType, int8_t>)
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
|
||||
is_same_v<BComputeType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs);
|
||||
}
|
||||
@@ -237,7 +262,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<ComputeType, float>)
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
|
||||
}
|
||||
@@ -245,28 +271,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
|
||||
#ifdef CK_ENABLE_FP8
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, ck::f8_t>)
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, ck::f8_t> &&
|
||||
is_same_v<BComputeType, ck::f8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<InDataType, ck::f8_t> && is_same_v<WeiDataType, ck::f8_t> &&
|
||||
is_same_v<OutDataType, ck::f8_t> && is_same_v<ComputeType, ck::f8_t>)
|
||||
is_same_v<OutDataType, ck::f8_t> && is_same_v<AComputeType, ck::f8_t> &&
|
||||
is_same_v<BComputeType, ck::f8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF8
|
||||
if constexpr(is_same_v<InDataType, ck::bf8_t> && is_same_v<WeiDataType, ck::bf8_t> &&
|
||||
is_same_v<OutDataType, ck::f8_t> && is_same_v<ComputeType, ck::bf8_t>)
|
||||
is_same_v<OutDataType, ck::f8_t> && is_same_v<AComputeType, ck::bf8_t> &&
|
||||
is_same_v<BComputeType, ck::bf8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8))
|
||||
if constexpr(is_same_v<InDataType, ck::f8_t> && is_same_v<WeiDataType, ck::bf8_t> &&
|
||||
is_same_v<OutDataType, ck::f8_t> && is_same_v<AComputeType, ck::f8_t> &&
|
||||
is_same_v<BComputeType, ck::bf8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
|
||||
}
|
||||
@@ -274,14 +312,17 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> && is_same_v<ComputeType, bhalf_t>)
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<AComputeType, ck::bhalf_t> &&
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeType, int8_t>)
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
|
||||
is_same_v<BComputeType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs);
|
||||
}
|
||||
@@ -295,7 +336,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs);
|
||||
@@ -305,7 +347,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeType, int8_t>)
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
|
||||
is_same_v<BComputeType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs);
|
||||
@@ -320,7 +363,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeType, int8_t>)
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
|
||||
is_same_v<BComputeType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs);
|
||||
@@ -335,7 +379,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances(
|
||||
@@ -347,7 +392,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeType, int8_t>)
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
|
||||
is_same_v<BComputeType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs);
|
||||
@@ -363,7 +409,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances(
|
||||
@@ -375,7 +422,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeType, int8_t>)
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
|
||||
is_same_v<BComputeType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(op_ptrs);
|
||||
|
||||
@@ -351,6 +351,24 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances(
|
||||
BF8>>>& instances);
|
||||
#endif
|
||||
|
||||
#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8))
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F8,
|
||||
BF8,
|
||||
Empty_Tuple,
|
||||
F8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
F8,
|
||||
BF8>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -41,4 +41,9 @@ if(DTYPES MATCHES "bf8" OR NOT DEFINED DTYPES)
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp)
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
|
||||
list(APPEND GROUPED_CONV3D_FWD
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp)
|
||||
endif()
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_instance ${GROUPED_CONV3D_FWD})
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F8,
|
||||
BF8,
|
||||
Empty_Tuple,
|
||||
F8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
F8,
|
||||
BF8>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f8_bf8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f8_bf8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f8_bf8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -31,7 +31,9 @@ template <ck::index_t NDimSpatial,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
typename OutDataType,
|
||||
typename AComputeType = InDataType,
|
||||
typename BComputeType = AComputeType>
|
||||
bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
@@ -209,7 +211,9 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
OutElementOp,
|
||||
AComputeType,
|
||||
BComputeType>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
|
||||
@@ -25,6 +25,7 @@ enum struct ConvDataType
|
||||
INT8_INT8_INT8, // 3
|
||||
F8_F8_F8, // 4
|
||||
BF8_BF8_F8, // 5
|
||||
F8_BF8_F8, // 6
|
||||
};
|
||||
|
||||
#define OP_NAME "grouped_conv_fwd"
|
||||
@@ -40,7 +41,8 @@ static void print_helper_msg()
|
||||
<< " 2: Input bf16, Weight bf16, Output bf16\n"
|
||||
<< " 3: Input int8, Weight int8, Output int8\n"
|
||||
<< " 4: Input fp8, Weight fp8, Output fp8\n"
|
||||
<< " 5: Input bf8, Weight bf8, Output fp8)\n"
|
||||
<< " 5: Input bf8, Weight bf8, Output fp8\n"
|
||||
<< " 6: Input fp8, Weight bf8, Output fp8)\n"
|
||||
<< "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
|
||||
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
|
||||
<< "arg4: verification (0: no, 1: yes)\n"
|
||||
@@ -118,7 +120,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
auto out_layout,
|
||||
auto in_type,
|
||||
auto wei_type,
|
||||
auto out_type) {
|
||||
auto out_type,
|
||||
auto a_compute_type,
|
||||
auto b_compute_type) {
|
||||
constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value;
|
||||
|
||||
using InLayout = decltype(in_layout);
|
||||
@@ -129,13 +133,18 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
using WeiDataType = decltype(wei_type);
|
||||
using OutDataType = decltype(out_type);
|
||||
|
||||
using AComputeType = decltype(a_compute_type);
|
||||
using BComputeType = decltype(b_compute_type);
|
||||
|
||||
bool pass = ck::profiler::profile_grouped_conv_fwd_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType>(
|
||||
OutDataType,
|
||||
AComputeType,
|
||||
BComputeType>(
|
||||
do_verification, init_method, do_log, time_kernel, params);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
@@ -146,57 +155,59 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{});
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{});
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, BF16{}, BF16{});
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::INT8_INT8_INT8)
|
||||
{
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{});
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{});
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{});
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{});
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::INT8_INT8_INT8)
|
||||
{
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{});
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{});
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{});
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{});
|
||||
return profile(
|
||||
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::INT8_INT8_INT8)
|
||||
{
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{});
|
||||
return profile(
|
||||
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
}
|
||||
// NHWGC_GKYXC_NHWGK
|
||||
@@ -204,65 +215,71 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{});
|
||||
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F16{}, F16{}, F16{});
|
||||
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I1, NWGC{}, GKXC{}, NWGK{}, BF16{}, BF16{}, BF16{});
|
||||
return profile(I1, NWGC{}, GKXC{}, NWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::INT8_INT8_INT8)
|
||||
{
|
||||
return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{});
|
||||
return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{});
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{});
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{});
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::INT8_INT8_INT8)
|
||||
{
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{});
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{});
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{});
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{});
|
||||
return profile(
|
||||
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::INT8_INT8_INT8)
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{});
|
||||
return profile(
|
||||
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F8_F8_F8)
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{});
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, F8{}, F8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF8_BF8_F8)
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, BF8{}, F8{});
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, BF8{}, F8{}, BF8{}, BF8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F8_BF8_F8)
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, BF8{}, F8{}, F8{}, BF8{});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user