mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add instances for conv_scale with bf8@fp8->fp8 (#1231)
* Add instances
* Add example
* Add profiler mode
* Add client example
[ROCm/composable_kernel commit: bbefc12a26]
This commit is contained in:
@@ -20,6 +20,9 @@ endif()
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
|
||||
add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
add_executable(client_conv3d_fwd_bf8_fp8 conv3d_fwd_bf8_fp8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations)
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES)
|
||||
|
||||
50
client_example/16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp
Normal file
50
client_example/16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp
Normal file
@@ -0,0 +1,50 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
using InDataType = ck::bf8_t;
|
||||
using WeiDataType = ck::f8_t;
|
||||
using OutDataType = ck::f8_t;
|
||||
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
using AComputeType = ck::bf8_t;
|
||||
using BComputeType = ck::f8_t;
|
||||
|
||||
static constexpr ck::index_t NumDimSpatial = 3;
|
||||
static constexpr ck::index_t G = 1;
|
||||
static constexpr ck::index_t N = 64;
|
||||
static constexpr ck::index_t K = 128;
|
||||
static constexpr ck::index_t C = 64;
|
||||
static constexpr ck::index_t Z = 3;
|
||||
static constexpr ck::index_t Y = 3;
|
||||
static constexpr ck::index_t X = 3;
|
||||
static constexpr ck::index_t Di = 28;
|
||||
static constexpr ck::index_t Hi = 28;
|
||||
static constexpr ck::index_t Wi = 3;
|
||||
static constexpr ck::index_t Do = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 3;
|
||||
|
||||
int main()
|
||||
{
|
||||
return run_grouped_conv_fwd<NumDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
3,
|
||||
AComputeType,
|
||||
BComputeType>(
|
||||
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
@@ -7,6 +7,7 @@ add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp16_comp_fp8 convnd_fwd_xdl_fp16_comp_fp8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp8_bf8 convnd_fwd_xdl_fp8_bf8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_bf8_fp8 convnd_fwd_xdl_bf8_fp8.cpp)
|
||||
add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp)
|
||||
add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp)
|
||||
add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp)
|
||||
|
||||
83
example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp
Normal file
83
example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp
Normal file
@@ -0,0 +1,83 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "convnd_fwd_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
|
||||
using InDataType = ck::bf8_t;
|
||||
using WeiDataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::f8_t;
|
||||
using OutDataType = ck::f8_t;
|
||||
using AComputeType = ck::bf8_t;
|
||||
using BComputeType = ck::f8_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvSpec =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
using DeviceGroupedConvNDFwdInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ck::Tuple<>,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
1, //
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
AComputeType,
|
||||
BComputeType>;
|
||||
|
||||
#include "run_convnd_fwd_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
|
||||
@@ -326,6 +326,42 @@ using device_grouped_conv_fwd_xdl_f8_bf8_instances = std::tuple<
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionForwardSpecialization ConvSpec>
|
||||
using device_grouped_conv_fwd_xdl_bf8_f8_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|AComputeType|BComputeType|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8))
|
||||
// generic instance
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>,
|
||||
// instances for small conv.K and conv.C
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
|
||||
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, F8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -301,6 +301,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8))
|
||||
if constexpr(is_same_v<InDataType, ck::bf8_t> && is_same_v<WeiDataType, ck::f8_t> &&
|
||||
is_same_v<OutDataType, ck::f8_t> && is_same_v<AComputeType, ck::bf8_t> &&
|
||||
is_same_v<BComputeType, ck::f8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
|
||||
@@ -369,6 +369,24 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(
|
||||
BF8>>>& instances);
|
||||
#endif
|
||||
|
||||
#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8))
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
BF8,
|
||||
F8,
|
||||
Empty_Tuple,
|
||||
F8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BF8,
|
||||
F8>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -44,6 +44,8 @@ endif()
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
|
||||
list(APPEND GROUPED_CONV3D_FWD
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp)
|
||||
list(APPEND GROUPED_CONV3D_FWD
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_fp8_instance.cpp)
|
||||
endif()
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_instance ${GROUPED_CONV3D_FWD})
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
BF8,
|
||||
F8,
|
||||
Empty_Tuple,
|
||||
F8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BF8,
|
||||
F8>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf8_f8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf8_f8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf8_f8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -26,6 +26,7 @@ enum struct ConvDataType
|
||||
F8_F8_F8, // 4
|
||||
BF8_BF8_F8, // 5
|
||||
F8_BF8_F8, // 6
|
||||
BF8_F8_F8, // 7
|
||||
};
|
||||
|
||||
#define OP_NAME "grouped_conv_fwd"
|
||||
@@ -42,7 +43,8 @@ static void print_helper_msg()
|
||||
<< " 3: Input int8, Weight int8, Output int8\n"
|
||||
<< " 4: Input fp8, Weight fp8, Output fp8\n"
|
||||
<< " 5: Input bf8, Weight bf8, Output fp8\n"
|
||||
<< " 6: Input fp8, Weight bf8, Output fp8)\n"
|
||||
<< " 6: Input fp8, Weight bf8, Output fp8\n"
|
||||
<< " 7: Input bf8, Weight fp8, Output fp8)\n"
|
||||
<< "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
|
||||
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
|
||||
<< "arg4: verification (0: no, 1: yes)\n"
|
||||
@@ -281,6 +283,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, BF8{}, F8{}, F8{}, BF8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF8_F8_F8)
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, BF8{}, F8{});
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
Reference in New Issue
Block a user