TF32 POC in Conv3d on MI30x platform #2763 (second attempt) (#2852)

* Revert "Revert "feature:tf32:add initial conv3d fwd kernel support (#2763)" (#2848)"

This reverts commit 03b59f8c76.

* fix compile error on gfx12x

* only run tf32 example on gfx942

* only build tf32 instance on gfx942

* ckProfiler: only support tf32 on gfx942

* delete useless messages
This commit is contained in:
yinglu
2025-09-18 05:50:15 +08:00
committed by GitHub
parent 7c934b72ab
commit dd7af118d7
45 changed files with 1147 additions and 181 deletions

View File

@@ -708,7 +708,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
LoopSched,
AComputeDataType,
BComputeDataType>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

View File

@@ -107,8 +107,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
using BComputeDataType =
conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
#else
using AComputeDataType = AComputeDataType_;
using BComputeDataType = BComputeDataType_;
using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::tf32_t>, float, AComputeDataType_>;
using BComputeDataType =
conditional_t<is_same_v<BComputeDataType_, ck::tf32_t>, float, BComputeDataType_>;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -659,26 +661,27 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
: false;
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<AComputeDataType,
MfmaSelector<AComputeDataType_,
MPerXdl,
NPerXdl,
BComputeDataType,
BComputeDataType_,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
AComputeDataType,
BComputeDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
AComputeDataType,
BComputeDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched,
AComputeDataType_,
BComputeDataType_>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

View File

@@ -144,7 +144,7 @@ template <typename ALayout,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v4,
typename BComputeDataType = AComputeDataType_>
typename BComputeDataType_ = AComputeDataType_>
struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
{
static constexpr index_t NumDTensor = DsDataType::Size();
@@ -172,7 +172,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
#else
using AComputeDataType = AComputeDataType_;
using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::tf32_t>, float, AComputeDataType_>;
using BComputeDataType =
conditional_t<is_same_v<BComputeDataType_, ck::tf32_t>, float, BComputeDataType_>;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -573,7 +576,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
// This forces m/n_block_data_idx_on_grid into SGPR.
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
@@ -640,10 +642,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<AComputeDataType,
MfmaSelector<AComputeDataType_,
MPerXdl,
NPerXdl,
BComputeDataType,
BComputeDataType_,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
@@ -659,7 +661,9 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
LoopSched,
AComputeDataType_,
BComputeDataType_>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();