mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
3d grouped conv fwd with input/output fp16 and comp fp8 (#931)
* add f8 comp instance
* fixed
* fixed comments
* rename
* fixed dtype
* format
* fixed CI
* fixed ci
* add missing ComputeType
* fixed cit
* fixed
* Update cmake-ck-dev.sh
---------
Co-authored-by: Jing Zhang <jizha@amd.com>
[ROCm/composable_kernel commit: e921e1f08d]
This commit is contained in:
@@ -1,5 +1,15 @@
|
||||
add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp)
|
||||
add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp)
|
||||
if((DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
|
||||
add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations)
|
||||
|
||||
target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations)
|
||||
target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations)
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
|
||||
add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_operations)
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES)
|
||||
add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations)
|
||||
endif()
|
||||
|
||||
@@ -94,7 +94,8 @@ template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
ck::index_t NumNonSpatialDim = 3>
|
||||
ck::index_t NumNonSpatialDim = 3,
|
||||
typename ComputeType = InDataType>
|
||||
bool run_grouped_conv_fwd(std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_lengths,
|
||||
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_lengths,
|
||||
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_lengths)
|
||||
@@ -184,7 +185,8 @@ bool run_grouped_conv_fwd(std::array<ck::index_t, NumDimSpatial + NumNonSpatialD
|
||||
OutDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
PassThrough,
|
||||
ComputeType>;
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
46
client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp
Normal file
46
client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp
Normal file
@@ -0,0 +1,46 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
using InDataType = ck::half_t;
|
||||
using WeiDataType = ck::half_t;
|
||||
using OutDataType = ck::half_t;
|
||||
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
static constexpr ck::index_t NumDimSpatial = 3;
|
||||
static constexpr ck::index_t G = 1;
|
||||
static constexpr ck::index_t N = 64;
|
||||
static constexpr ck::index_t K = 128;
|
||||
static constexpr ck::index_t C = 64;
|
||||
static constexpr ck::index_t Z = 3;
|
||||
static constexpr ck::index_t Y = 3;
|
||||
static constexpr ck::index_t X = 3;
|
||||
static constexpr ck::index_t Di = 28;
|
||||
static constexpr ck::index_t Hi = 28;
|
||||
static constexpr ck::index_t Wi = 3;
|
||||
static constexpr ck::index_t Do = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 3;
|
||||
|
||||
int main()
|
||||
{
|
||||
return run_grouped_conv_fwd<NumDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
3,
|
||||
ck::f8_t>(
|
||||
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
Reference in New Issue
Block a user