mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Add instances of grouped convolution 3d forward with a ConvScale element-wise op for bf8@bf8->fp8 (#1326)
We are adding more instances of grouped convolution 3d forward with a ConvScale element-wise operation. This commit handles bf8@bf8->fp8 data types combination. * Included an example. * Added instances. * Added a client example. --------- Co-authored-by: Rostyslav Geyyer <rosty.geyyer@amd.com> Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
This commit is contained in:
committed by
GitHub
parent
fa129c1a5d
commit
05b10e0e5a
@@ -40,9 +40,14 @@ add_executable(client_conv3d_fwd_convinvscale_fp8
|
||||
grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_convinvscale_fp8 PRIVATE composable_kernel::device_conv_operations)
|
||||
# Fwd convscale
|
||||
add_executable(client_conv3d_fwd_convscale_fp8
|
||||
add_executable(client_conv3d_fwd_convscale_fp8
|
||||
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_convscale_fp8 PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
add_executable(client_conv3d_fwd_convscale_bf8
|
||||
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_convscale_bf8 PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
add_executable(client_conv3d_fwd_convscale_fp8_bf8
|
||||
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_convscale_fp8_bf8 PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
using InDataType = ck::bf8_t;
|
||||
using WeiDataType = ck::bf8_t;
|
||||
using CShuffleDataType = float;
|
||||
using OutDataType = ck::f8_t;
|
||||
using AComputeDataType = InDataType;
|
||||
using BComputeDataType = AComputeDataType;
|
||||
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
static constexpr ck::index_t NumDimSpatial = 3;
|
||||
static constexpr ck::index_t G = 1;
|
||||
static constexpr ck::index_t N = 64;
|
||||
static constexpr ck::index_t K = 128;
|
||||
static constexpr ck::index_t C = 64;
|
||||
static constexpr ck::index_t Z = 3;
|
||||
static constexpr ck::index_t Y = 3;
|
||||
static constexpr ck::index_t X = 3;
|
||||
static constexpr ck::index_t Di = 28;
|
||||
static constexpr ck::index_t Hi = 28;
|
||||
static constexpr ck::index_t Wi = 3;
|
||||
static constexpr ck::index_t Do = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 3;
|
||||
|
||||
int main()
|
||||
{
|
||||
return run_grouped_conv_fwd_convscale<NumDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
3,
|
||||
AComputeDataType,
|
||||
BComputeDataType>(
|
||||
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
Reference in New Issue
Block a user