mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Add instances of grouped convolution 3d forward with a ConvScale element-wise op for bf8@bf8->fp8 (#1326)
We are adding more instances of grouped convolution 3d forward with a ConvScale element-wise operation. This commit handles the bf8@bf8->fp8 data type combination. * Included an example. * Added instances. * Added a client example. --------- Co-authored-by: Rostyslav Geyyer <rosty.geyyer@amd.com> Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
This commit is contained in:
committed by
GitHub
parent
fa129c1a5d
commit
05b10e0e5a
@@ -40,9 +40,14 @@ add_executable(client_conv3d_fwd_convinvscale_fp8
|
||||
grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_convinvscale_fp8 PRIVATE composable_kernel::device_conv_operations)
|
||||
# Fwd convscale
|
||||
add_executable(client_conv3d_fwd_convscale_fp8
|
||||
add_executable(client_conv3d_fwd_convscale_fp8
|
||||
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_convscale_fp8 PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
add_executable(client_conv3d_fwd_convscale_bf8
|
||||
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_convscale_bf8 PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
add_executable(client_conv3d_fwd_convscale_fp8_bf8
|
||||
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_convscale_fp8_bf8 PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
using InDataType = ck::bf8_t;
|
||||
using WeiDataType = ck::bf8_t;
|
||||
using CShuffleDataType = float;
|
||||
using OutDataType = ck::f8_t;
|
||||
using AComputeDataType = InDataType;
|
||||
using BComputeDataType = AComputeDataType;
|
||||
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
static constexpr ck::index_t NumDimSpatial = 3;
|
||||
static constexpr ck::index_t G = 1;
|
||||
static constexpr ck::index_t N = 64;
|
||||
static constexpr ck::index_t K = 128;
|
||||
static constexpr ck::index_t C = 64;
|
||||
static constexpr ck::index_t Z = 3;
|
||||
static constexpr ck::index_t Y = 3;
|
||||
static constexpr ck::index_t X = 3;
|
||||
static constexpr ck::index_t Di = 28;
|
||||
static constexpr ck::index_t Hi = 28;
|
||||
static constexpr ck::index_t Wi = 3;
|
||||
static constexpr ck::index_t Do = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 3;
|
||||
|
||||
int main()
|
||||
{
|
||||
return run_grouped_conv_fwd_convscale<NumDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
3,
|
||||
AComputeDataType,
|
||||
BComputeDataType>(
|
||||
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
@@ -6,46 +6,36 @@ if (DTYPES)
|
||||
add_definitions(-DDTYPES)
|
||||
if (DTYPES MATCHES "int8")
|
||||
add_definitions(-DCK_ENABLE_INT8)
|
||||
if(NOT DEFINED ${CK_ENABLE_INT8})
|
||||
set(CK_ENABLE_INT8 "ON")
|
||||
endif()
|
||||
set(CK_ENABLE_INT8 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp8")
|
||||
add_definitions(-DCK_ENABLE_FP8)
|
||||
if(NOT DEFINED ${CK_ENABLE_FP8})
|
||||
set(CK_ENABLE_FP8 "ON")
|
||||
endif()
|
||||
set(CK_ENABLE_FP8 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "bf8")
|
||||
add_definitions(-DCK_ENABLE_BF8)
|
||||
set(CK_ENABLE_BF8 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp16")
|
||||
add_definitions(-DCK_ENABLE_FP16)
|
||||
if(NOT DEFINED ${CK_ENABLE_FP16})
|
||||
set(CK_ENABLE_FP16 "ON")
|
||||
endif()
|
||||
set(CK_ENABLE_FP16 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp32")
|
||||
add_definitions(-DCK_ENABLE_FP32)
|
||||
if(NOT DEFINED ${CK_ENABLE_FP32})
|
||||
set(CK_ENABLE_FP32 "ON")
|
||||
endif()
|
||||
set(CK_ENABLE_FP32 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp64")
|
||||
add_definitions(-DCK_ENABLE_FP64)
|
||||
if(NOT DEFINED ${CK_ENABLE_FP64})
|
||||
set(CK_ENABLE_FP64 "ON")
|
||||
endif()
|
||||
set(CK_ENABLE_FP64 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "bf16")
|
||||
add_definitions(-DCK_ENABLE_BF16)
|
||||
if(NOT DEFINED ${CK_ENABLE_BF16})
|
||||
set(CK_ENABLE_BF16 "ON")
|
||||
endif()
|
||||
set(CK_ENABLE_BF16 "ON")
|
||||
endif()
|
||||
message("DTYPES macro set to ${DTYPES}")
|
||||
else()
|
||||
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
|
||||
if(NOT DEFINED ${CK_ENABLE_ALL_DTYPES})
|
||||
set(CK_ENABLE_ALL_DTYPES "ON")
|
||||
endif()
|
||||
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
|
||||
set(CK_ENABLE_ALL_DTYPES "ON")
|
||||
endif()
|
||||
|
||||
if (GPU_TARGETS)
|
||||
@@ -73,7 +63,8 @@ message(STATUS "Build with HIP ${hip_VERSION}")
|
||||
# add all example subdir
|
||||
file(GLOB dir_list LIST_DIRECTORIES true *)
|
||||
FOREACH(subdir ${dir_list})
|
||||
IF(IS_DIRECTORY "${subdir}" AND (NOT "${subdir}" MATCHES "build"))
|
||||
IF(IS_DIRECTORY "${subdir}" AND (NOT "${subdir}" MATCHES "build")
|
||||
AND (NOT "${subdir}" MATCHES ".vscode"))
|
||||
add_subdirectory(${subdir})
|
||||
ENDIF()
|
||||
ENDFOREACH()
|
||||
|
||||
Reference in New Issue
Block a user