From d4dbf931193914f0a77a0d4085f4d83f3dcfefdf Mon Sep 17 00:00:00 2001 From: lym Date: Mon, 15 Sep 2025 21:03:00 +0800 Subject: [PATCH] feature:tf32:add initial conv3d fwd kernel support (#2763) [ROCm/composable_kernel commit: c51102144f481be65b7aa803e830ce1f684b2f02] --- example/01_gemm/CMakeLists.txt | 3 + example/01_gemm/common.hpp | 16 ++- .../gemm_xdl_lds_direct_load_fp32_tf32.cpp | 85 +++++++++++ example/01_gemm/run_gemm_example.inc | 13 +- example/09_convnd_fwd/CMakeLists.txt | 3 +- example/09_convnd_fwd/convnd_fwd_common.hpp | 29 ++-- .../convnd_fwd_xdl_fp32_tf32.cpp | 89 ++++++++++++ example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp | 4 + .../09_convnd_fwd/run_convnd_fwd_example.inc | 27 ++-- include/ck/host_utility/device_prop.hpp | 2 + include/ck/library/utility/check_err.hpp | 2 +- .../gpu/block/blockwise_gemm_xdlops.hpp | 85 ++++++----- ...vice_gemm_xdl_cshuffle_lds_direct_load.hpp | 12 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 51 ++++++- ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 4 +- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 39 ++--- ...ultiple_d_xdl_cshuffle_lds_direct_load.hpp | 16 ++- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 136 +++++++++++++++--- include/ck/utility/amd_xdlops.hpp | 41 ++++++ include/ck/utility/data_type.hpp | 35 +++++ include/ck/utility/type_convert.hpp | 13 ++ .../cpu/reference_conv_fwd.hpp | 7 +- .../cpu/reference_gemm.hpp | 40 ++++-- .../gpu/reference_gemm.hpp | 19 ++- .../device_operation_instance_factory.hpp | 1 + ...ouped_conv_fwd_xdl_dynamic_op_instance.hpp | 1 + .../device_grouped_conv_fwd_xdl_instance.hpp | 43 +++++- .../gpu/grouped_convolution_forward.hpp | 6 + ...grouped_convolution_forward_bias_clamp.hpp | 8 ++ ...ped_convolution_forward_bias_clamp_xdl.inc | 16 +++ .../gpu/grouped_convolution_forward_clamp.hpp | 8 ++ .../grouped_convolution_forward_clamp_xdl.inc | 16 +++ ...grouped_convolution_forward_dynamic_op.hpp | 14 +- .../gpu/grouped_convolution_forward_xdl.inc | 16 +++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 3 +- ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 56 ++++++++ .../CMakeLists.txt | 59 ++++---- ..._ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in | 81 +++++++++++ .../CMakeLists.txt | 4 +- ...dhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp | 60 ++++++++ .../grouped_conv3d_fwd_clamp/CMakeLists.txt | 4 +- ...dhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp | 60 ++++++++ library/src/utility/host_tensor.cpp | 2 +- profiler/src/profile_grouped_conv_fwd.cpp | 31 ++-- 44 files changed, 1085 insertions(+), 175 deletions(-) create mode 100644 example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp create mode 100644 example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 61f3ba5351..bae9fb9e24 100644 --- a/example/01_gemm/CMakeLists.txt +++ 
b/example/01_gemm/CMakeLists.txt @@ -87,6 +87,9 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32) + add_example_executable(example_gemm_xdl_lds_direct_load_fp32_tf32 gemm_xdl_lds_direct_load_fp32_tf32.cpp) + add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32_tf32) + add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 434f549443..e482953e46 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -310,10 +310,14 @@ bool parse_cmd_args(int argc, return true; } -template +template inline __host__ __device__ constexpr double get_rtol() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) { return 1e-3; } @@ -351,10 +355,14 @@ inline __host__ __device__ constexpr double get_rtol() } } -template +template inline __host__ __device__ constexpr double get_atol() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) { return 1e-3; } diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp new file mode 100644 index 0000000000..9b92fad779 --- /dev/null +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
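A short orientation note on the example that starts here (an editorial addition; every name below comes from hunks in this patch): gemm_xdl_lds_direct_load_fp32_tf32.cpp keeps all tensor data types at fp32 and only swaps ComputeDataType to ck::tf32_t, and it defines EXAMPLE_WITH_COMPUTE_DATATYPE so the shared run_gemm_example.inc keeps that compute type instead of defaulting it. The fallback in the .inc (also added by this patch) is:

    #ifndef EXAMPLE_WITH_COMPUTE_DATATYPE
    using ComputeDataType = AccDataType; // examples without an explicit compute type keep the old behaviour
    #endif

Verification then passes the compute type to get_rtol/get_atol, which resolve to the relaxed 1e-3 tolerances added to common.hpp above for the float/tf32_t pair.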
+ +#include + +#include "common.hpp" + +#define USING_DIRECT_LOADS 1 +#if USING_DIRECT_LOADS +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp" +#else +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#endif + +#define EXAMPLE_WITH_COMPUTE_DATATYPE + +using F32 = float; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F32; +using ComputeDataType = ck::tf32_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +#if USING_DIRECT_LOADS +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_LdsDirectLoad +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| +// ######| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| +// ######| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler | pipeline ver | gemm type | +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| +// ######| XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, + 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, + 1, 1, S<1, 8, 1, 8>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v4, ComputeDataType>; +// clang-format on +#else +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | 
Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>; +// clang-format on +#endif +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 3e018aad1e..08e2b8c15f 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -4,6 +4,11 @@ #pragma once #include "ck/library/utility/validation_common.hpp" +// use macro to minimize code change +#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE +using ComputeDataType = AccDataType; +#endif + template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { @@ -218,8 +223,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", - get_rtol(), - get_atol()); + get_rtol(), + get_atol()); #endif } @@ -249,8 +254,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_device_ref_result, "Error: Incorrect results!", - get_rtol(), - get_atol()); + get_rtol(), + get_atol()); } return pass == true; diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 91c072aef7..72a1cb2afb 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -1,4 +1,5 @@ add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) +add_example_executable(example_convnd_fwd_xdl_fp32_tf32 convnd_fwd_xdl_fp32_tf32.cpp) add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) @@ -19,4 +20,4 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) set(target 1) endif() -endforeach() \ No newline at end of file +endforeach() diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp index b0fd6a382a..d82b56ec00 100644 --- a/example/09_convnd_fwd/convnd_fwd_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -27,10 +27,14 @@ void print_helper_msg() << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; } -template +template inline __host__ __device__ constexpr double get_rtol() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && std::is_same_v) + { + return 5e-3; + } + else if constexpr(std::is_same_v) { return 1e-3; } @@ -68,10 +72,14 @@ inline __host__ __device__ constexpr double get_rtol() } } -template +template inline __host__ __device__ constexpr double get_atol() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && 
std::is_same_v) + { + return 1e-2; + } + else if constexpr(std::is_same_v) { return 1e-3; } @@ -116,7 +124,8 @@ template + typename DeviceConvNDFwdInstance, + typename ComputeDataType = OutDataType> bool run_grouped_conv_fwd(bool do_verification, int init_method, bool time_kernel, @@ -228,7 +237,11 @@ bool run_grouped_conv_fwd(bool do_verification, OutDataType, InElementOp, WeiElementOp, - OutElementOp>(); + OutElementOp, + 0, + 0, + 0, + ComputeDataType>(); auto ref_invoker = ref_conv.MakeInvoker(); auto ref_argument = ref_conv.MakeArgument(in, @@ -249,8 +262,8 @@ bool run_grouped_conv_fwd(bool do_verification, return ck::utils::check_err(out_device, out_host, "Error: incorrect results!", - get_rtol(), - get_atol()); + get_rtol(), + get_atol()); } return true; diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp new file mode 100644 index 0000000000..348da7e1ef --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +#define EXAMPLE_WITH_COMPUTE_DATATYPE + +using InDataType = float; +using WeiDataType = float; +using AccDataType = float; +using CShuffleDataType = float; +using OutDataType = float; +using ComputeDataType = ck::tf32_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, // ALayout + WeiLayout, // BLayout + ck::Tuple<>, // DsLayout + OutLayout, // ELayout + InDataType, // ADataType + WeiDataType, // BDataType + AccDataType, // AccDataType + CShuffleDataType, // CShuffleDataType + ck::Tuple<>, // DsDataType + OutDataType, // EDataType + InElementOp, // AElementwiseOperation + WeiElementOp, // BElementwiseOperation + OutElementOp, // CDEElementwiseOperation + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 192, // NPerBlock + 16, // KPerBlock + 4, // AK1 + 4, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 3, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector + 4, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // 
CShuffleNXdlPerWavePerShuffle + S<1, 16, 1, 16>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4, // CDEBlockTransferScalarPerVector_NPerBlock + ComputeDataType, // AComputeDataType + ComputeDataType, // BComputeDataType + ck::LoopScheduler::Default, // LoopScheduler + 1 // NumGroupsToMerge + >; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp index fde0f51bc7..c635d01d8f 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp @@ -7,6 +7,8 @@ #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#define EXAMPLE_WITH_COMPUTE_DATATYPE + using InDataType = ck::f8_t; using WeiDataType = ck::f8_t; using AccDataType = float; @@ -87,3 +89,5 @@ int main(int argc, char* argv[]) } return run_convnd_fwd_example(argc, argv) ? 0 : 1; } + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/09_convnd_fwd/run_convnd_fwd_example.inc b/example/09_convnd_fwd/run_convnd_fwd_example.inc index 49852ff667..016a189d4b 100644 --- a/example/09_convnd_fwd/run_convnd_fwd_example.inc +++ b/example/09_convnd_fwd/run_convnd_fwd_example.inc @@ -3,6 +3,11 @@ #pragma once +// use macro to minimize code change +#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE +using ComputeDataType = AccDataType; +#endif + bool run_convnd_fwd_example(int argc, char* argv[]) { print_helper_msg(); @@ -65,17 +70,17 @@ bool run_convnd_fwd_example(int argc, char* argv[]) InElementOp, WeiElementOp, OutElementOp, - DeviceGroupedConvNDFwdInstance>( - do_verification, - init_method, - time_kernel, - conv_param, - in_g_n_c_wis_desc, - wei_g_k_c_xs_desc, - out_g_n_k_wos_desc, - in_element_op, - wei_element_op, - out_element_op); + DeviceGroupedConvNDFwdInstance, + ComputeDataType>(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); }; namespace ctc = ck::tensor_layout::convolution; diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 6b04b21e4f..919f6f91c7 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -134,5 +134,7 @@ inline bool is_wmma_supported() return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported(); } +inline bool is_tf32_supported() { return (ck::get_device_name() == "gfx942") ? true : false; } + } // namespace ck #endif diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index d33ecaeef8..185166f7ec 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -180,13 +180,13 @@ check_err(const Range& out, if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { max_err = err > max_err ? 
err : max_err; - err_count++; if(err_count < 5) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; } res = false; + err_count++; } } if(!res) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index e848ca35b5..55015dd30f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -49,6 +49,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 using ThisThreadBlock = ThisThreadBlock; + using ElementDataTypeA = + conditional_t, float, ComputeTypeA>; + using ElementDataTypeB = + conditional_t, float, ComputeTypeB>; + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); static constexpr index_t KPerBlock = @@ -64,7 +69,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; static constexpr auto xdlops_gemm = - XdlopsGemm{}; + XdlopsGemm{}; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; @@ -172,6 +177,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, "wrong!"); + if constexpr(is_same_v || is_same_v) + { + static_assert(is_same_v, + "ComputeTypeA and ComputeTypeB must be same when one of them is tf32"); + } } __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() @@ -297,9 +307,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -321,20 +331,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 b_thread_buf); static_for<0, KPerThread, KPack>{}([&](auto k) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = a_thread_buf + a_thread_vec.template AsType()(i) = a_thread_buf [Number{}]; - b_thread_vec.template AsType()(i) = b_thread_buf + b_thread_vec.template AsType()(i) = b_thread_buf [Number{}]; }); using mfma_input_type_a = - typename vector_type::type; + typename vector_type::type; using mfma_input_type_b = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -361,7 +371,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -371,7 +381,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -445,6 +455,11 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 using Base::KPerThread; using Base::xdlops_gemm; + using ElementDataTypeA = + conditional_t, float, ComputeTypeA>; + using ElementDataTypeB = + conditional_t, float, ComputeTypeB>; + static constexpr index_t 
KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); // 2-wave optimized blockwise gemm @@ -453,9 +468,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) { @@ -499,22 +514,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = + a_thread_vec.template AsType()(i) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = + b_thread_vec.template AsType()(i) = b_thread_buf[Number{}]; }); using mfma_input_type_a = - typename vector_type::type; + typename vector_type::type; using mfma_input_type_b = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -563,7 +578,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 make_tuple(Number{}, I1, I1, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -573,7 +588,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -622,19 +637,21 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() } else if constexpr(LoopSched == LoopScheduler::Interwave) { - return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< + BlockSize, + FloatA, + FloatB, + FloatAcc, + AK0MK1BlockDesc, + BK0NK1BlockDesc, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack, + ComputeTypeA, + ComputeTypeB, + CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>{}; } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp index 8daaafaed1..23b0faec67 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp @@ -119,7 +119,9 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm; + PipelineVer, + ComputeDataType>; + using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; @@ -214,6 +216,14 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm) + { + if(!is_tf32_supported()) + { + return false; + } + } + // Check vector load/store. 
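// Editorial aside (not patch text): is_tf32_supported() added in device_prop.hpp simply
// compares ck::get_device_name() against "gfx942", so the TF32 compute path above is
// only ever reported as supported on gfx942. Host code can mirror the same guard before
// requesting a *_f32_tf32_* instance, e.g.
//     if(ck::is_tf32_supported()) { /* pick a TF32 instance */ }
// on any other target this IsSupportedArgument() returns false.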
{ using Row = ck::tensor_layout::gemm::RowMajor; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 1412c960c7..cc8561a09f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -1003,11 +1003,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle void Print() const { + std::cout << "AComputeDataType: " << get_type_name() + << "; BComputeDataType: " << get_type_name() + << "; EDataType: " << get_type_name() << std::endl; + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; static_for<0, NumDTensor, 1>{}( [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + + std::cout << "a grid desc" << a_grid_desc_ak0_m_ak1_ << std::endl; + std::cout << "b grid desc" << b_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "e grid desc" << e_grid_desc_mblock_mperblock_nblock_nperblock_ + << std::endl; } // private: @@ -1198,7 +1207,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle isMultiA, isMultiB, CTranspose>; - return launch_and_time_kernel( stream_config, kernel, @@ -1281,7 +1289,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float avg_time = 0.f; - if constexpr(NeedTransposeKernel) { const index_t a_grid_size = @@ -1686,7 +1693,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { return false; } - + if constexpr(is_same_v || + is_same_v) + { + if(!is_tf32_supported()) + { + return false; + } + if constexpr(!is_same_v) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ComputeDataType for A and B should be same while using TF32" + << std::endl; + } + return false; + } + } // check Gridwise GEMM if(get_warp_size() == 64) { @@ -1766,6 +1789,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } } + if constexpr(is_same_v || + is_same_v) + + { + if(!(ck::get_device_name() == "gfx942")) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "TF32 is enabled on gfx942 only" << std::endl; + } + return false; + } + if constexpr(!is_same_v) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ComputeDataType for A and B should be same while using TF32" + << std::endl; + } + return false; + } + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index c198711dbb..cbad6a5673 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -708,7 +708,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle MXdlPerWave, NXdlPerWave, KPack, - LoopSched>(); + LoopSched, + AComputeDataType, + BComputeDataType>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 59d7f357ec..a97e4503a8 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -107,8 +107,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle using BComputeDataType = conditional_t, ck::bhalf_t, BComputeDataType_>; #else - using AComputeDataType = AComputeDataType_; - using BComputeDataType = BComputeDataType_; + using AComputeDataType = + conditional_t, float, AComputeDataType_>; + using BComputeDataType = + conditional_t, float, BComputeDataType_>; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -659,26 +661,27 @@ struct GridwiseGemmMultipleD_xdl_cshuffle : false; constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); - - auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< - BlockSize, - AComputeDataType, - BComputeDataType, - AccDataType, - decltype(a_block_desc_ak0_m_ak1), - decltype(b_block_desc_bk0_n_bk1), - MPerXdl, - NPerXdl, - MXdlPerWave, - NXdlPerWave, - KPack, - LoopSched>(); + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + AComputeDataType, + BComputeDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched, + AComputeDataType_, + BComputeDataType_>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index 095b1c5d63..1e72e78349 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -144,7 +144,7 @@ template + typename BComputeDataType_ = AComputeDataType_> struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad { static constexpr index_t NumDTensor = DsDataType::Size(); @@ -172,7 +172,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad using AComputeDataType = conditional_t, ck::bhalf_t, AComputeDataType_>; #else - using AComputeDataType = AComputeDataType_; + using AComputeDataType = + conditional_t, float, AComputeDataType_>; + using BComputeDataType = + conditional_t, float, BComputeDataType_>; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -573,7 +576,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad // This forces m/n_block_data_idx_on_grid into SGPR. 
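// Editorial sketch of the pattern used in the hunk above (reconstructed from context,
// not literal patch text): the gridwise GEMM keeps data movement in plain float and
// only forwards the tf32 tag so the blockwise GEMM / MfmaSelector can pick the xf32
// MFMA instructions, roughly:
//   using AComputeDataType =
//       conditional_t<is_same_v<tf32_t, AComputeDataType_>, float, AComputeDataType_>;
//   ...BlockwiseGemmXdlops..._Selector<..., KPack, LoopSched,
//                                      AComputeDataType_, BComputeDataType_>();
// i.e. AComputeDataType (float) sizes the thread buffers and LDS traffic, while the
// trailing AComputeDataType_/BComputeDataType_ (possibly tf32_t) select the instruction.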
const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); @@ -640,10 +642,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); @@ -659,7 +661,9 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad MXdlPerWave, NXdlPerWave, KPack, - LoopSched>(); + LoopSched, + AComputeDataType_, + BComputeDataType_>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index deea6ae9cc..a97d9589cf 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -41,11 +41,11 @@ static constexpr bool scale_mfma_hw_support() enum struct MfmaInstr { - mfma_f32_32x32x1xf32 = 0, - mfma_f32_16x16x1xf32, - mfma_f32_4x4x1xf32, - mfma_f32_32x32x2xf32, - mfma_f32_16x16x4xf32, + mfma_f32_32x32x1f32 = 0, + mfma_f32_16x16x1f32, + mfma_f32_4x4x1f32, + mfma_f32_32x32x2f32, + mfma_f32_16x16x4f32, mfma_f32_32x32x4f16, mfma_f32_16x16x4f16, mfma_f32_4x4x4f16, @@ -78,6 +78,8 @@ enum struct MfmaInstr mfma_f32_16x16x128f8f6f4, mfma_scale_f32_32x32x64f8f6f4, mfma_scale_f32_16x16x128f8f6f4, + mfma_f32_16x16x8xf32, // tf32 + mfma_f32_32x32x4xf32, // gfx11 wmma_f32_16x16x16_f16, wmma_f32_16x16x16_bf16, @@ -98,7 +100,7 @@ template struct mfma_type; template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 4; @@ -120,7 +122,7 @@ struct mfma_type }; template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 4; @@ -142,7 +144,7 @@ struct mfma_type }; template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 1; @@ -164,7 +166,7 @@ struct mfma_type }; template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 1; @@ -187,7 +189,7 @@ struct mfma_type // treat 4x4x1 as a single-blk 4x64 mfma template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 1; @@ -947,6 +949,70 @@ struct mfma_type } }; +/** + * num_threads_per_blk == n_per_blk + * num_regs_per_blk * num_input_blks == m_per_blk + * num_regs_per_blk * wave_size == m_per_blk * n_per_blk + * + * group_size * num_groups_per_blk == num_regs_per_blk + * + * num_regs_per_blk is output(CD) register size which is determined by the instruction. + * k_per_blk(K1PerXdlops) is input(AB) register size which is determined by the instruction. + * group_size is corresponding to CD rows mapping. see: GetBeginOfThreadBlk() + * + * is_k_reduction = (k_per_blk == KPerXdlops) ? false: true. 
+ * + * if (is_k_reduction){ + * num_output_blks == 1; + * } else { + * num_input_blks == num_output_blks; + * } + */ +template <> +struct mfma_type +{ + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t m_per_blk = 16; // from the instruction + static constexpr index_t n_per_blk = 16; // from the instruction + static constexpr index_t num_threads_per_blk = n_per_blk; // 16 + static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 4 + static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 4 + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t k_per_blk = 2; // k_per_blk(K1PerXdlops) should be 2. + static constexpr bool is_k_reduction = true; + + // AB register size : 2, register size: 4 + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x8xf32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t m_per_blk = 32; // from the instruction + static constexpr index_t n_per_blk = 32; // from the instruction + static constexpr index_t num_threads_per_blk = n_per_blk; // 32 + static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 16 + static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 2 + static constexpr index_t group_size = 4; // corresponding to CD rows mapping + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = true; + // AB register size: 2, CD register size: 16 + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x4xf32::Run(a, b, reg_c); + } +}; + // gfx11 struct mfma_type_gfx11_base { @@ -1116,6 +1182,20 @@ struct mfma_type : public mfma_type_gfx12 } }; +/** + * @class MfmaSelector + * @brief Selects the appropriate MFMA instruction type and configuration for given data types + * and tile sizes on AMD GPUs. + * + * @tparam base_type The base data type for the matrix operation (e.g., float, half_t). + * @tparam MPerXdlops The number of rows per XDLops tile. + * @tparam NPerXdlops The number of columns per XDLops tile. + * @tparam additional_type (Optional) Additional data type for mixed-precision or special cases. + * Defaults to base_type. + * @tparam is_single_rate_mfma (Optional) Whether to use single-rate MFMA instructions. + * Defaults to false. + * @tparam is_scale_mfma (Optional) Whether to use scale MFMA instructions. Defaults to false. 
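 *
 * Worked example for the TF32 additions in this patch (editorial, derived from the
 * mfma_type entries above): base_type = tf32_t with a 16x16 tile selects
 * mfma_f32_16x16x8xf32, giving num_regs_per_blk = 16*16/64 = 4,
 * num_input_blks = 16/4 = 4, k_per_blk = 2 and KPerXdlops = 2*4 = 8, hence
 * is_k_reduction == true; a 32x32 tile selects mfma_f32_32x32x4xf32 with
 * num_regs_per_blk = 16 and num_input_blks = 2. Operands stay in fp32 registers and
 * are reduced to TF32 precision inside the MFMA instruction itself.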
+ */ template constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_32x32x1xf32; + return MfmaInstr::mfma_f32_32x32x1f32; } template <> constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_32x32x1xf32; + return MfmaInstr::mfma_f32_32x32x1f32; } template <> constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_16x16x1xf32; + return MfmaInstr::mfma_f32_16x16x1f32; } template <> constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_4x4x1xf32; + return MfmaInstr::mfma_f32_4x4x1f32; } template <> constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_4x4x1xf32; + return MfmaInstr::mfma_f32_4x4x1f32; } template <> constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_32x32x2xf32; + return MfmaInstr::mfma_f32_32x32x2f32; } template <> @@ -1188,10 +1268,22 @@ struct MfmaSelector #elif defined(__gfx11__) return MfmaInstr::wmma_unsupport_16x16_gfx11; #else - return MfmaInstr::mfma_f32_16x16x4xf32; + return MfmaInstr::mfma_f32_16x16x4f32; #endif } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x4xf32; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x8xf32; + } + template <> constexpr auto GetMfma() { @@ -1896,7 +1988,7 @@ struct XdlopsGemm __device__ __host__ static constexpr index_t GetRegSizePerXdlops() { - return MPerXdlops * NPerXdlops / mfma_instr.wave_size; + return mfma_instr.num_regs_per_blk; } __device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; } @@ -1906,12 +1998,12 @@ struct XdlopsGemm { static_assert( is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || (is_same::value && is_same::value) || (is_same::value && is_same::value), - "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"); + "base_type must be double, float, tf32_t, half, bfloat16, int8_t, f8_t or bf8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { if constexpr(!TransposeC) diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 02a7a72b8c..be3a5cea42 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1636,4 +1636,45 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> } }; +/******************* tf32 *************************************/ +template +struct intrin_mfma_f32_16x16x8xf32; + +template <> +struct intrin_mfma_f32_16x16x8xf32<16, 16> +{ + template + __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_f32_32x32x4xf32; + +template <> +struct intrin_mfma_f32_32x32x4xf32<32, 32> +{ + template + __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + } // namespace ck diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 5fbe30d21b..48b352986e 100644 --- a/include/ck/utility/data_type.hpp +++ 
b/include/ck/utility/data_type.hpp @@ -26,6 +26,7 @@ using byte = unsigned char; using std::byte; #endif +using tf32_t = _BitInt(19); // 1 sign bit, 8 exponent bits, 10 mantissa bits using bhalf_t = ushort; using half_t = _Float16; using int4_t = _BitInt(4); @@ -461,4 +462,38 @@ using int64_t = long long; using int64_t = long; #endif +template +inline const char* get_type_name() +{ + if constexpr(is_same_v) + return "fp16"; + else if constexpr(is_same_v) + return "bf16"; + else if constexpr(is_same_v) + return "tf32"; + else if constexpr(is_same_v) + return "int4"; + else if constexpr(is_same_v) + return "f4"; + else if constexpr(is_same_v) + return "f6"; + else if constexpr(is_same_v) + return "bf6"; + else if constexpr(is_same_v) + return "f8"; + else if constexpr(is_same_v) + return "bf8"; + else if constexpr(is_same_v) + return "e8m0"; + else if constexpr(is_same_v) + return "fp32"; +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) + else + return "unknown"; +#else + else + return typeid(T).name(); +#endif +} + } // namespace ck diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 8e53728ef6..290a6c8dd6 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -187,6 +187,19 @@ inline __host__ __device__ constexpr bf8_ocp_t type_convert(int return bf8_ocp_t{type_convert(x)}; } +template , bool> = false> +inline __host__ __device__ constexpr float type_convert(float x) +{ + union + { + float fp32; + uint32_t int32; + } u = {x}; + + u.int32 = u.int32 & 0xffffe000; + return u.fp32; +} + // Convert X to Y template __host__ __device__ constexpr Y type_convert_sp(X x) diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp index 3884902bbf..7ae12e3551 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -59,6 +59,7 @@ template = 1 && NDimSpatial <= 3, bool>::type = false> struct ReferenceConvFwd : public device::BaseOperator { @@ -327,8 +328,10 @@ struct ReferenceConvFwd : public device::BaseOperator z, y, x); - v_acc += ck::type_convert(v_in) * - ck::type_convert(v_wei); + v_acc += ck::type_convert( + ck::type_convert(v_in)) * + ck::type_convert( + ck::type_convert(v_wei)); } } } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index ed07e53e6d..8b9b973b2d 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -25,6 +25,12 @@ template struct ReferenceGemm : public device::BaseOperator { + + using ElementDataTypeA = + ck::conditional_t, float, ComputeTypeA>; + using ElementDataTypeB = + ck::conditional_t, float, ComputeTypeB>; + // Argument struct Argument : public device::BaseArgument { @@ -63,8 +69,8 @@ struct ReferenceGemm : public device::BaseOperator const int K = arg.a_m_k_.mDesc.GetLengths()[1]; AccDataType v_acc{0}; - ComputeTypeA v_a{0}; - ComputeTypeB v_b{0}; + ElementDataTypeA v_a{0}; + ElementDataTypeB v_b{0}; for(int k = 0; k < K; ++k) { @@ -77,16 +83,16 @@ struct ReferenceGemm : public device::BaseOperator else i4 = (i4x2 >> 4) & 0xf; i4 = i4 - 8; - v_a = type_convert(i4); + v_a = type_convert(i4); } else if 
constexpr(is_same_v) { // TODO: add support for ColMajor layout as well if(k % 2 == 1) - v_a = type_convert( + v_a = type_convert( f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))); else - v_a = type_convert( + v_a = type_convert( f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))); } else if constexpr(is_same_v || @@ -94,7 +100,7 @@ struct ReferenceGemm : public device::BaseOperator is_same_v || is_same_v) { - v_a = type_convert( + v_a = type_convert( arg.a_m_k_(m, k).unpack(k % ADataType::packed_size)); } else @@ -111,16 +117,16 @@ struct ReferenceGemm : public device::BaseOperator else i4 = (i4x2 >> 4) & 0xf; i4 = i4 - 8; - v_b = type_convert(i4); + v_b = type_convert(i4); } else if constexpr(is_same_v) { // TODO: add support for RowMajor layout as well if(k % 2 == 1) - v_b = type_convert( + v_b = type_convert( f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))); else - v_b = type_convert( + v_b = type_convert( f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))); } else if constexpr(is_same_v || @@ -128,7 +134,7 @@ struct ReferenceGemm : public device::BaseOperator is_same_v || is_same_v) { - v_b = type_convert( + v_b = type_convert( arg.b_k_n_(k, n).unpack(k % BDataType::packed_size)); } else @@ -136,8 +142,18 @@ struct ReferenceGemm : public device::BaseOperator arg.b_element_op_(v_b, arg.b_k_n_(k, n)); } - v_acc += - ck::type_convert(v_a) * ck::type_convert(v_b); + if constexpr(is_same_v && + is_same_v) + { // only for tf32 now + v_acc += + ck::type_convert(ck::type_convert(v_a)) * + ck::type_convert(ck::type_convert(v_b)); + } + else + { + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } } CDataType v_c{0}; diff --git a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp index 28274a5154..cf30bc7dda 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp @@ -38,6 +38,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CDEElementwiseOperation c_element_op) { using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ElementDataTypeA = + ck::conditional_t, float, ComputeTypeA>; + using ElementDataTypeB = + ck::conditional_t, float, ComputeTypeB>; const int row_idx = blockIdx.x * blockDim.x + threadIdx.x; const int col_idx = blockIdx.y * blockDim.y + threadIdx.y; @@ -46,8 +50,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) { AccDataType v_acc{0}; - ComputeTypeA v_a{0}; - ComputeTypeB v_b{0}; + ElementDataTypeA v_a{0}; + ElementDataTypeB v_b{0}; CDataType v_c{0}; for(int k_idx = 0; k_idx < k; ++k_idx) @@ -76,7 +80,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) // apply b_element_op b_element_op(v_b, p_b_grid[element_idx_b]); // multiply and accumulate - v_acc += type_convert(v_a) * type_convert(v_b); + if constexpr(is_same_v && + is_same_v) + { // only for tf32 now + v_acc += ck::type_convert(ck::type_convert(v_a)) * + ck::type_convert(ck::type_convert(v_b)); + } + else + { + v_acc += type_convert(v_a) * type_convert(v_b); + } } // apply c_element_op c_element_op(v_c, v_acc); diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 7164f345cd..9aeca39718 100644 --- 
a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -16,6 +16,7 @@ namespace instance { // aliasing, for commonly used data type using F64 = double; using F32 = float; +using TF32 = ck::tf32_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; using I8 = int8_t; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp index 82c01a634b..568f0e0dc4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp @@ -16,6 +16,7 @@ namespace instance { using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; template using S = ck::Sequence; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index 768fcbada0..52c389d020 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -24,6 +24,7 @@ using BF8 = ck::bf8_t; using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; template using S = ck::Sequence; @@ -199,7 +200,7 @@ using device_grouped_conv_fwd_xdl_f16_nchw_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>, - // 32x32 instance + // 32x32 instance DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, @@ -284,7 +285,45 @@ using device_grouped_conv_fwd_xdl_f32_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_f32_tf32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| AComputeType| BComputeType| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| DATATYPE | DATATYPE | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 
0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 545826650c..5a26abecc2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -443,6 +443,12 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(op_ptrs); + } #endif #ifdef CK_ENABLE_FP8 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp index 43411b0031..11e827878c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp @@ -215,6 +215,14 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } #endif } #endif // CK_USE_XDL diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc index aaaacb0d18..045d1623cf 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc @@ -578,6 +578,22 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_insta PassThrough, AddClamp>>>& instances); +void 
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( std::vector && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } #endif } #endif // CK_USE_XDL diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc index d5a8a5344a..b0061b966d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc @@ -578,6 +578,22 @@ void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( PassThrough, Clamp>>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( std::vector>>& instances); + #endif #ifdef CK_ENABLE_INT8 @@ -159,7 +160,8 @@ template + typename AComputeType, + typename BComputeType = AComputeType> struct DeviceOperationInstanceFactory> + AComputeType, + BComputeType>> { using DeviceOp = DeviceGroupedConvFwdMultipleABD; + AComputeType, + BComputeType>; static auto GetInstances() { @@ -207,7 +211,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instances( op_ptrs); @@ -244,7 +248,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index a3f2515099..af6041bbc5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -559,6 +559,22 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation 
+} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt index bda9149227..6a776b4943 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -2,7 +2,7 @@ set(GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP) include(ShardInstantiation) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances @@ -11,7 +11,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances @@ -20,7 +20,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances @@ -29,7 +29,16 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances + TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in + NUM_SHARDS 16 + SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl +) + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances @@ -38,7 +47,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances @@ -47,7 +56,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances @@ -58,7 +67,7 @@ generate_sharded_instantiations( ) # large tensor # NDHWGC, GKZYXC, NDHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances @@ -67,7 +76,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances @@ -76,7 +85,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) - + set(GENERATED_DIR 
${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances @@ -87,7 +96,7 @@ generate_sharded_instantiations( ) # merged groups # NDHWGC, GKZYXC, NDHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances @@ -96,7 +105,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances @@ -105,7 +114,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances @@ -116,7 +125,7 @@ generate_sharded_instantiations( ) #mem # NDHWGC, GKZYXC, NDHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances @@ -125,7 +134,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances @@ -134,7 +143,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances @@ -144,7 +153,7 @@ generate_sharded_instantiations( OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) # NDHWGC, GKZYXC, NDHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances @@ -153,7 +162,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances @@ -162,7 +171,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances @@ -173,7 +182,7 @@ generate_sharded_instantiations( ) #comp # NDHWGC, GKZYXC, NDHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances @@ -182,7 +191,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR 
${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances @@ -191,7 +200,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances @@ -200,7 +209,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances @@ -209,7 +218,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances @@ -218,7 +227,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances @@ -227,7 +236,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in new file mode 100644 index 0000000000..d7f3c87b83 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances = + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances_shard( + device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1P0, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt index 3bd6916cf0..bcc7020ca9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt @@ -23,6 +23,8 @@ set(GROUPED_CONV3D_FWD xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp -) + + xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp + ) add_instance_library(device_grouped_conv3d_fwd_bias_clamp_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp new file mode 100644 index 0000000000..328838bff2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Tuple, + AddClamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt index 234533244e..059d22f8d2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt @@ -23,6 +23,8 @@ set(GROUPED_CONV3D_FWD xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp -) + + xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp + ) add_instance_library(device_grouped_conv3d_fwd_clamp_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp new file mode 100644 index 0000000000..a1bf6562c2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/utility/host_tensor.cpp b/library/src/utility/host_tensor.cpp index 7211552641..02bd562e43 100644 --- a/library/src/utility/host_tensor.cpp +++ b/library/src/utility/host_tensor.cpp @@ -53,7 +53,7 @@ std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc) os << "strides {"; LogRange(os, desc.GetStrides(), ", "); - os << "}"; + os << "} "; return os; } diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index a7714b4c73..4ddc7ed077 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -21,14 +21,15 @@ enum struct ConvLayout enum struct ConvDataType { - F32_F32_F32, // 0 - F16_F16_F16, // 1 - BF16_BF16_BF16, // 2 - INT8_INT8_INT8, // 3 - F8_F8_F8, // 4 - BF8_BF8_F8, // 5 - F8_BF8_F8, // 6 - BF8_F8_F8, // 7 + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 + BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 + BF8_F8_F8, // 7 + F32_F32_F32_TF32, // 8 }; enum struct IndexType @@ -53,6 +54,7 @@ static void print_helper_msg() << " 5: Input bf8, Weight bf8, Output fp8\n" << " 6: Input fp8, Weight bf8, Output fp8\n" << " 7: Input bf8, Weight fp8, Output fp8)\n" + << " 8: Input fp32, Weight fp32, Output fp32, Compute tf32\n" << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " @@ -103,6 +105,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using INT8 = int8_t; using F8 = ck::f8_t; using BF8 = ck::bf8_t; + using TF32 = ck::tf32_t; // using GNWC = ck::tensor_layout::convolution::GNWC; @@ -261,6 +264,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) return profile( I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } } // NHWGC_GKYXC_NHWGK else if(num_dim_spatial == 1 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -367,6 +374,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, BF8{}, F8{}); } + else if(data_type == 
ConvDataType::F32_F32_F32_TF32) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } } // NGCDHW_GKCZYX_NGKDHW else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKCYX_NGKHW) @@ -384,6 +395,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) return profile( I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } } std::cout << "this data_type & layout is not implemented" << std::endl;
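
The example below is not part of the patch; it is a minimal sketch of how the new F32-with-TF32-compute grouped conv3d forward instances could be enumerated through the CK instance factory once this change lands. The header paths, the template-parameter order of DeviceGroupedConvFwdMultipleABD, and the layout/elementwise-operation aliases are assumptions inferred from the declarations added in this patch; verify them against the headers in your checkout.

// Sketch: list the f32 (TF32-compute) NDHWGC/GKZYXC/NDHWGK conv3d fwd instances.
// The DeviceGroupedConvFwdMultipleABD parameter order below is an assumption
// based on the factory specialization added in grouped_convolution_forward.hpp,
// not code taken verbatim from the patch.
#include <iostream>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"

int main()
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using NDHWGC      = ck::tensor_layout::convolution::NDHWGC;
    using GKZYXC      = ck::tensor_layout::convolution::GKZYXC;
    using NDHWGK      = ck::tensor_layout::convolution::NDHWGK;

    // F32 input/weight/output with tf32_t as the A/B compute type, matching the
    // condition the patch adds to the conv3d branch of the instance factory.
    using DeviceOp = ck::tensor_operation::device::
        DeviceGroupedConvFwdMultipleABD<3,           // NumDimSpatial
                                        NDHWGC,      // input layout
                                        GKZYXC,      // weight layout
                                        ck::Tuple<>, // no D tensors
                                        NDHWGK,      // output layout
                                        float,       // ADataType
                                        float,       // BDataType
                                        ck::Tuple<>, // DsDataType
                                        float,       // EDataType
                                        PassThrough,
                                        PassThrough,
                                        PassThrough,
                                        ck::tf32_t,  // AComputeType
                                        ck::tf32_t>; // BComputeType

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " f32/tf32 conv3d fwd instances\n";
    for(const auto& op : op_ptrs)
    {
        std::cout << "  " << op->GetTypeString() << '\n';
    }

    return 0;
}

This mirrors the profiler's new F32_F32_F32_TF32 path (data type option 8), which passes F32 tensor types with TF32{} compute types into the same instance machinery; on devices without TF32 XDL support, each instance's own argument check is expected to reject the workload at run time.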