mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
* Revert "Revert "feature:tf32:add initial conv3d fwd kernel support (#2763)" (#2848)"
This reverts commit 03b59f8c76.
* fix compile error on gf12x
* only run tf32 example on gfx942
* only build tf32 instance on gfx942
* ckProfiler:only support tf32 in gfx942
* delete unuseful messages
This commit is contained in:
@@ -708,7 +708,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched>();
|
||||
LoopSched,
|
||||
AComputeDataType,
|
||||
BComputeDataType>();
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
|
||||
@@ -107,8 +107,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
using BComputeDataType =
|
||||
conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
|
||||
#else
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
using BComputeDataType = BComputeDataType_;
|
||||
using AComputeDataType =
|
||||
conditional_t<is_same_v<AComputeDataType_, ck::tf32_t>, float, AComputeDataType_>;
|
||||
using BComputeDataType =
|
||||
conditional_t<is_same_v<BComputeDataType_, ck::tf32_t>, float, BComputeDataType_>;
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
@@ -659,26 +661,27 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
: false;
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeDataType,
|
||||
MfmaSelector<AComputeDataType_,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
BComputeDataType,
|
||||
BComputeDataType_,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
AComputeDataType,
|
||||
BComputeDataType,
|
||||
AccDataType,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched>();
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
AComputeDataType,
|
||||
BComputeDataType,
|
||||
AccDataType,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched,
|
||||
AComputeDataType_,
|
||||
BComputeDataType_>();
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ template <typename ALayout,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v4,
|
||||
typename BComputeDataType = AComputeDataType_>
|
||||
typename BComputeDataType_ = AComputeDataType_>
|
||||
struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
@@ -172,7 +172,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
using AComputeDataType =
|
||||
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
|
||||
#else
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
using AComputeDataType =
|
||||
conditional_t<is_same_v<AComputeDataType_, ck::tf32_t>, float, AComputeDataType_>;
|
||||
using BComputeDataType =
|
||||
conditional_t<is_same_v<BComputeDataType_, ck::tf32_t>, float, BComputeDataType_>;
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
@@ -573,7 +576,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
// This forces m/n_block_data_idx_on_grid into SGPR.
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
|
||||
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
|
||||
|
||||
@@ -640,10 +642,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
constexpr auto is_scale_mfma = false;
|
||||
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeDataType,
|
||||
MfmaSelector<AComputeDataType_,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
BComputeDataType,
|
||||
BComputeDataType_,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
@@ -659,7 +661,9 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched>();
|
||||
LoopSched,
|
||||
AComputeDataType_,
|
||||
BComputeDataType_>();
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user