Revert "feature:tf32:add initial conv3d fwd kernel support (#2763)" (#2848)

This reverts commit c51102144f.
This commit is contained in:
Illia Silin
2025-09-15 08:27:04 -07:00
committed by GitHub
parent c51102144f
commit 03b59f8c76
44 changed files with 175 additions and 1085 deletions

View File

@@ -119,9 +119,7 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer,
ComputeDataType>;
PipelineVer>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
@@ -216,14 +214,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
return false;
}
if constexpr(is_same_v<ComputeDataType, ck::tf32_t>)
{
if(!is_tf32_supported())
{
return false;
}
}
// Check vector load/store.
{
using Row = ck::tensor_layout::gemm::RowMajor;

View File

@@ -1003,20 +1003,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
void Print() const
{
std::cout << "AComputeDataType: " << get_type_name<AComputeDataType>()
<< "; BComputeDataType: " << get_type_name<BComputeDataType>()
<< "; EDataType: " << get_type_name<EDataType>() << std::endl;
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
std::cout << "a grid desc" << a_grid_desc_ak0_m_ak1_ << std::endl;
std::cout << "b grid desc" << b_grid_desc_bk0_n_bk1_ << std::endl;
std::cout << "e grid desc" << e_grid_desc_mblock_mperblock_nblock_nperblock_
<< std::endl;
}
// private:
@@ -1207,6 +1198,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
isMultiA,
isMultiB,
CTranspose>;
return launch_and_time_kernel(
stream_config,
kernel,
@@ -1289,6 +1281,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float avg_time = 0.f;
if constexpr(NeedTransposeKernel)
{
const index_t a_grid_size =
@@ -1693,23 +1686,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
return false;
}
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
is_same_v<BComputeDataType, ck::tf32_t>)
{
if(!is_tf32_supported())
{
return false;
}
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "ComputeDataType for A and B should be same while using TF32"
<< std::endl;
}
return false;
}
}
// check Gridwise GEMM
if(get_warp_size() == 64)
{
@@ -1789,28 +1766,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
is_same_v<BComputeDataType, ck::tf32_t>)
{
if(!(ck::get_device_name() == "gfx942"))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "TF32 is enabled on gfx942 only" << std::endl;
}
return false;
}
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "ComputeDataType for A and B should be same while using TF32"
<< std::endl;
}
return false;
}
}
return false;
}