Revert "feature:tf32:add initial conv3d fwd kernel support (#2763)" (#2848)

This reverts commit c51102144f.
This commit is contained in:
Illia Silin
2025-09-15 08:27:04 -07:00
committed by GitHub
parent c51102144f
commit 03b59f8c76
44 changed files with 175 additions and 1085 deletions

View File

@@ -708,9 +708,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched,
AComputeDataType,
BComputeDataType>();
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

View File

@@ -107,10 +107,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
using BComputeDataType =
conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
#else
using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::tf32_t>, float, AComputeDataType_>;
using BComputeDataType =
conditional_t<is_same_v<BComputeDataType_, ck::tf32_t>, float, BComputeDataType_>;
using AComputeDataType = AComputeDataType_;
using BComputeDataType = BComputeDataType_;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -661,27 +659,26 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
: false;
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<AComputeDataType_,
MfmaSelector<AComputeDataType,
MPerXdl,
NPerXdl,
BComputeDataType_,
BComputeDataType,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
AComputeDataType,
BComputeDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched,
AComputeDataType_,
BComputeDataType_>();
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
AComputeDataType,
BComputeDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

View File

@@ -144,7 +144,7 @@ template <typename ALayout,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v4,
typename BComputeDataType_ = AComputeDataType_>
typename BComputeDataType = AComputeDataType_>
struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
{
static constexpr index_t NumDTensor = DsDataType::Size();
@@ -172,10 +172,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
#else
using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::tf32_t>, float, AComputeDataType_>;
using BComputeDataType =
conditional_t<is_same_v<BComputeDataType_, ck::tf32_t>, float, BComputeDataType_>;
using AComputeDataType = AComputeDataType_;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -576,6 +573,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
// This forces m/n_block_data_idx_on_grid into SGPR.
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
@@ -642,10 +640,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<AComputeDataType_,
MfmaSelector<AComputeDataType,
MPerXdl,
NPerXdl,
BComputeDataType_,
BComputeDataType,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
@@ -661,9 +659,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched,
AComputeDataType_,
BComputeDataType_>();
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();