mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
This reverts commit c51102144f.
This commit is contained in:
@@ -708,9 +708,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched,
|
||||
AComputeDataType,
|
||||
BComputeDataType>();
|
||||
LoopSched>();
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
|
||||
@@ -107,10 +107,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
using BComputeDataType =
|
||||
conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
|
||||
#else
|
||||
using AComputeDataType =
|
||||
conditional_t<is_same_v<AComputeDataType_, ck::tf32_t>, float, AComputeDataType_>;
|
||||
using BComputeDataType =
|
||||
conditional_t<is_same_v<BComputeDataType_, ck::tf32_t>, float, BComputeDataType_>;
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
using BComputeDataType = BComputeDataType_;
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
@@ -661,27 +659,26 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
: false;
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeDataType_,
|
||||
MfmaSelector<AComputeDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
BComputeDataType_,
|
||||
BComputeDataType,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
AComputeDataType,
|
||||
BComputeDataType,
|
||||
AccDataType,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched,
|
||||
AComputeDataType_,
|
||||
BComputeDataType_>();
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
AComputeDataType,
|
||||
BComputeDataType,
|
||||
AccDataType,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched>();
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ template <typename ALayout,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v4,
|
||||
typename BComputeDataType_ = AComputeDataType_>
|
||||
typename BComputeDataType = AComputeDataType_>
|
||||
struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
@@ -172,10 +172,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
using AComputeDataType =
|
||||
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
|
||||
#else
|
||||
using AComputeDataType =
|
||||
conditional_t<is_same_v<AComputeDataType_, ck::tf32_t>, float, AComputeDataType_>;
|
||||
using BComputeDataType =
|
||||
conditional_t<is_same_v<BComputeDataType_, ck::tf32_t>, float, BComputeDataType_>;
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
@@ -576,6 +573,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
// This forces m/n_block_data_idx_on_grid into SGPR.
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
|
||||
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
|
||||
|
||||
@@ -642,10 +640,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
constexpr auto is_scale_mfma = false;
|
||||
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeDataType_,
|
||||
MfmaSelector<AComputeDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
BComputeDataType_,
|
||||
BComputeDataType,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
@@ -661,9 +659,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched,
|
||||
AComputeDataType_,
|
||||
BComputeDataType_>();
|
||||
LoopSched>();
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user