Add BF16 example.

This commit is contained in:
Ville Pietilä
2026-02-09 06:36:20 -05:00
parent 60d1ec34a9
commit b722492a30
4 changed files with 43 additions and 14 deletions

View File

@@ -8,6 +8,9 @@ add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_con
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
add_example_executable(example_grouped_conv_fwd_xdl_bf16 grouped_conv_fwd_xdl_bf16.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_bf16)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)

View File

@@ -13,6 +13,7 @@
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

View File

@@ -0,0 +1,24 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "common.hpp"
// kernel data types
using InKernelDataType = BF16;
using WeiKernelDataType = BF16;
using AccDataType = FP32;
using CShuffleDataType = BF16;
using OutKernelDataType = BF16;
// tensor data types
using InUserDataType = InKernelDataType;
using WeiUserDataType = WeiKernelDataType;
using OutUserDataType = OutKernelDataType;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = PassThrough;
#include "run_grouped_conv_fwd_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_example(argc, argv); }

View File

@@ -3,7 +3,7 @@
template <ck::index_t NDimSpatial>
using DeviceConvFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
NDimSpatial,
InputLayout<NDimSpatial>,
WeightLayout<NDimSpatial>,
@@ -20,35 +20,36 @@ using DeviceConvFwdInstance =
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
16, // KPerBlock
4, // AK1
4, // BK1
256, // MPerBlock
128, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
8, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
4, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_AK1
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
4, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_BK1
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
4>;
4,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3>;
template <ck::index_t NDimSpatial>
using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,