Revert "feature:tf32:add initial conv3d fwd kernel support (#2763)" (#2848)

This reverts commit c51102144f.
This commit is contained in:
Illia Silin
2025-09-15 08:27:04 -07:00
committed by GitHub
parent c51102144f
commit 03b59f8c76
44 changed files with 175 additions and 1085 deletions

View File

@@ -1,5 +1,4 @@
add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
add_example_executable(example_convnd_fwd_xdl_fp32_tf32 convnd_fwd_xdl_fp32_tf32.cpp)
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
@@ -20,4 +19,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
set(target 1)
endif()
endforeach()
endforeach()

View File

@@ -27,14 +27,10 @@ void print_helper_msg()
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}
template <typename DataType, typename GemmType = DataType>
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
{
return 5e-3;
}
else if constexpr(std::is_same_v<DataType, float>)
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
@@ -72,14 +68,10 @@ inline __host__ __device__ constexpr double get_rtol()
}
}
template <typename DataType, typename GemmType = DataType>
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
{
return 1e-2;
}
else if constexpr(std::is_same_v<DataType, float>)
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
@@ -124,8 +116,7 @@ template <ck::index_t NDimSpatial,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvNDFwdInstance,
typename ComputeDataType = OutDataType>
typename DeviceConvNDFwdInstance>
bool run_grouped_conv_fwd(bool do_verification,
int init_method,
bool time_kernel,
@@ -237,11 +228,7 @@ bool run_grouped_conv_fwd(bool do_verification,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
0,
0,
0,
ComputeDataType>();
OutElementOp>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in,
@@ -262,8 +249,8 @@ bool run_grouped_conv_fwd(bool do_verification,
return ck::utils::check_err(out_device,
out_host,
"Error: incorrect results!",
get_rtol<OutDataType, ComputeDataType>(),
get_atol<OutDataType, ComputeDataType>());
get_rtol<OutDataType>(),
get_atol<OutDataType>());
}
return true;

View File

@@ -1,89 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#define EXAMPLE_WITH_COMPUTE_DATATYPE
using InDataType = float;
using WeiDataType = float;
using AccDataType = float;
using CShuffleDataType = float;
using OutDataType = float;
using ComputeDataType = ck::tf32_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout, // ALayout
WeiLayout, // BLayout
ck::Tuple<>, // DsLayout
OutLayout, // ELayout
InDataType, // ADataType
WeiDataType, // BDataType
AccDataType, // AccDataType
CShuffleDataType, // CShuffleDataType
ck::Tuple<>, // DsDataType
OutDataType, // EDataType
InElementOp, // AElementwiseOperation
WeiElementOp, // BElementwiseOperation
OutElementOp, // CDEElementwiseOperation
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, // NumGemmKPrefetchStage
256, // BlockSize
128, // MPerBlock
192, // NPerBlock
16, // KPerBlock
4, // AK1
4, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
3, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
4, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
4, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 16, 1, 16>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4, // CDEBlockTransferScalarPerVector_NPerBlock
ComputeDataType, // AComputeDataType
ComputeDataType, // BComputeDataType
ck::LoopScheduler::Default, // LoopScheduler
1 // NumGroupsToMerge
>;
#include "run_convnd_fwd_example.inc"
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
#undef EXAMPLE_WITH_COMPUTE_DATATYPE

View File

@@ -7,8 +7,6 @@
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#define EXAMPLE_WITH_COMPUTE_DATATYPE
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
@@ -89,5 +87,3 @@ int main(int argc, char* argv[])
}
return run_convnd_fwd_example(argc, argv) ? 0 : 1;
}
#undef EXAMPLE_WITH_COMPUTE_DATATYPE

View File

@@ -3,11 +3,6 @@
#pragma once
// use macro to minimize code change
#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE
using ComputeDataType = AccDataType;
#endif
bool run_convnd_fwd_example(int argc, char* argv[])
{
print_helper_msg();
@@ -70,17 +65,17 @@ bool run_convnd_fwd_example(int argc, char* argv[])
InElementOp,
WeiElementOp,
OutElementOp,
DeviceGroupedConvNDFwdInstance<ndim_spatial_value, InLayout, WeiLayout, OutLayout>,
ComputeDataType>(do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op);
DeviceGroupedConvNDFwdInstance<ndim_spatial_value, InLayout, WeiLayout, OutLayout>>(
do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op);
};
namespace ctc = ck::tensor_layout::convolution;