mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Add BF16 example.
This commit is contained in:
@@ -8,6 +8,9 @@ add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_con
|
||||
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
|
||||
|
||||
add_example_executable(example_grouped_conv_fwd_xdl_bf16 grouped_conv_fwd_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_bf16)
|
||||
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
|
||||
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
// kernel data types
|
||||
using InKernelDataType = BF16;
|
||||
using WeiKernelDataType = BF16;
|
||||
using AccDataType = FP32;
|
||||
using CShuffleDataType = BF16;
|
||||
using OutKernelDataType = BF16;
|
||||
|
||||
// tensor data types
|
||||
using InUserDataType = InKernelDataType;
|
||||
using WeiUserDataType = WeiKernelDataType;
|
||||
using OutUserDataType = OutKernelDataType;
|
||||
|
||||
using InElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using OutElementOp = PassThrough;
|
||||
|
||||
#include "run_grouped_conv_fwd_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_example(argc, argv); }
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvFwdInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
NDimSpatial,
|
||||
InputLayout<NDimSpatial>,
|
||||
WeightLayout<NDimSpatial>,
|
||||
@@ -20,35 +20,36 @@ using DeviceConvFwdInstance =
|
||||
OutElementOp,
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
1, //
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
16, // KPerBlock
|
||||
4, // AK1
|
||||
4, // BK1
|
||||
256, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
8, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
4, // ABlockTransferSrcScalarPerVector
|
||||
4, // ABlockTransferDstScalarPerVector_AK1
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
4, // BBlockTransferSrcScalarPerVector
|
||||
4, // BBlockTransferDstScalarPerVector_BK1
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
4>;
|
||||
4,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v3>;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
|
||||
|
||||
Reference in New Issue
Block a user