Add instances for conv_scale with bf8@fp8->fp8 (#1231)

* Add instances

* Add example

* Add profiler mode

* Add client example
This commit is contained in:
Rostyslav Geyyer
2024-04-11 10:35:00 -05:00
committed by GitHub
parent b2735caf46
commit bbefc12a26
10 changed files with 262 additions and 1 deletions

View File

@@ -20,6 +20,9 @@ endif()
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp)
target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations)
add_executable(client_conv3d_fwd_bf8_fp8 conv3d_fwd_bf8_fp8.cpp)
target_link_libraries(client_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations)
endif()
if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES)

View File

@@ -0,0 +1,50 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using InDataType = ck::bf8_t;
using WeiDataType = ck::f8_t;
using OutDataType = ck::f8_t;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
using AComputeType = ck::bf8_t;
using BComputeType = ck::f8_t;
static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G = 1;
static constexpr ck::index_t N = 64;
static constexpr ck::index_t K = 128;
static constexpr ck::index_t C = 64;
static constexpr ck::index_t Z = 3;
static constexpr ck::index_t Y = 3;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Di = 28;
static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3;
int main()
{
return run_grouped_conv_fwd<NumDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout,
3,
AComputeType,
BComputeType>(
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
? EXIT_SUCCESS
: EXIT_FAILURE;
}