mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
* Revert "Revert "feature:tf32:add initial conv3d fwd kernel support (#2763)" (#2848)"
This reverts commit 03b59f8c76.
* fix compile error on gf12x
* only run tf32 example on gfx942
* only build tf32 instance on gfx942
* ckProfiler:only support tf32 in gfx942
* delete unuseful messages
This commit is contained in:
@@ -49,6 +49,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using ElementDataTypeA =
|
||||
conditional_t<is_same_v<ComputeTypeA, ck::tf32_t>, float, ComputeTypeA>;
|
||||
using ElementDataTypeB =
|
||||
conditional_t<is_same_v<ComputeTypeB, ck::tf32_t>, float, ComputeTypeB>;
|
||||
|
||||
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t KPerBlock =
|
||||
@@ -64,7 +69,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
|
||||
|
||||
static constexpr auto xdlops_gemm =
|
||||
XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB>{};
|
||||
XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB, false, false>{};
|
||||
|
||||
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
|
||||
|
||||
@@ -172,6 +177,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
|
||||
"wrong!");
|
||||
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
|
||||
{
|
||||
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
|
||||
"ComputeTypeA and ComputeTypeB must be same when one of them is tf32");
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
|
||||
@@ -297,9 +307,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ElementDataTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ElementDataTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
@@ -321,20 +331,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
b_thread_buf);
|
||||
|
||||
static_for<0, KPerThread, KPack>{}([&](auto k) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ElementDataTypeA, KPack> a_thread_vec;
|
||||
vector_type<ElementDataTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(i) = a_thread_buf
|
||||
a_thread_vec.template AsType<ElementDataTypeA>()(i) = a_thread_buf
|
||||
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(i) = b_thread_buf
|
||||
b_thread_vec.template AsType<ElementDataTypeB>()(i) = b_thread_buf
|
||||
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ElementDataTypeA, xdlops_gemm.K1PerXdlops>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ElementDataTypeB, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -361,7 +371,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
|
||||
ComputeTypeA,
|
||||
ElementDataTypeA,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerThread>,
|
||||
@@ -371,7 +381,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
|
||||
ComputeTypeB,
|
||||
ElementDataTypeB,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerThread>,
|
||||
@@ -445,6 +455,11 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
using Base::KPerThread;
|
||||
using Base::xdlops_gemm;
|
||||
|
||||
using ElementDataTypeA =
|
||||
conditional_t<is_same_v<ComputeTypeA, ck::tf32_t>, float, ComputeTypeA>;
|
||||
using ElementDataTypeB =
|
||||
conditional_t<is_same_v<ComputeTypeB, ck::tf32_t>, float, ComputeTypeB>;
|
||||
|
||||
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
|
||||
|
||||
// 2-wave optimized blockwise gemm
|
||||
@@ -453,9 +468,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ElementDataTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ElementDataTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
|
||||
@@ -499,22 +514,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ElementDataTypeA, KPack> a_thread_vec;
|
||||
vector_type<ElementDataTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(i) =
|
||||
a_thread_vec.template AsType<ElementDataTypeA>()(i) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, 0, 0, k_ + i))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(i) =
|
||||
b_thread_vec.template AsType<ElementDataTypeB>()(i) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, 0, 0, k_ + i))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ElementDataTypeA, xdlops_gemm.K1PerXdlops>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ElementDataTypeB, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -563,7 +578,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
|
||||
ComputeTypeA,
|
||||
ElementDataTypeA,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerInnerLoop>,
|
||||
@@ -573,7 +588,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
|
||||
ComputeTypeB,
|
||||
ElementDataTypeB,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerInnerLoop>,
|
||||
@@ -622,19 +637,21 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
|
||||
}
|
||||
else if constexpr(LoopSched == LoopScheduler::Interwave)
|
||||
{
|
||||
return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatA,
|
||||
FloatB,
|
||||
FloatAcc,
|
||||
AK0MK1BlockDesc,
|
||||
BK0NK1BlockDesc,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>{};
|
||||
return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
|
||||
BlockSize,
|
||||
FloatA,
|
||||
FloatB,
|
||||
FloatAcc,
|
||||
AK0MK1BlockDesc,
|
||||
BK0NK1BlockDesc,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>{};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -119,7 +119,9 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
PipelineVer,
|
||||
ComputeDataType>;
|
||||
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
|
||||
@@ -214,6 +216,14 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<ComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check vector load/store.
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
@@ -1003,11 +1003,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
void Print() const
|
||||
{
|
||||
std::cout << "AComputeDataType: " << get_type_name<AComputeDataType>()
|
||||
<< "; BComputeDataType: " << get_type_name<BComputeDataType>()
|
||||
<< "; EDataType: " << get_type_name<EDataType>() << std::endl;
|
||||
|
||||
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
|
||||
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
|
||||
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
|
||||
|
||||
std::cout << "a grid desc" << a_grid_desc_ak0_m_ak1_ << std::endl;
|
||||
std::cout << "b grid desc" << b_grid_desc_bk0_n_bk1_ << std::endl;
|
||||
std::cout << "e grid desc" << e_grid_desc_mblock_mperblock_nblock_nperblock_
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// private:
|
||||
@@ -1198,7 +1207,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
isMultiA,
|
||||
isMultiB,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
@@ -1281,7 +1289,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float avg_time = 0.f;
|
||||
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const index_t a_grid_size =
|
||||
@@ -1686,7 +1693,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
|
||||
is_same_v<BComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
if(!is_tf32_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// check Gridwise GEMM
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
@@ -1766,6 +1789,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
|
||||
is_same_v<BComputeDataType, ck::tf32_t>)
|
||||
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx942"))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "TF32 is enabled on gfx942 only" << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "ComputeDataType for A and B should be same while using TF32"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -708,7 +708,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched>();
|
||||
LoopSched,
|
||||
AComputeDataType,
|
||||
BComputeDataType>();
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
|
||||
@@ -107,8 +107,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
using BComputeDataType =
|
||||
conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
|
||||
#else
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
using BComputeDataType = BComputeDataType_;
|
||||
using AComputeDataType =
|
||||
conditional_t<is_same_v<AComputeDataType_, ck::tf32_t>, float, AComputeDataType_>;
|
||||
using BComputeDataType =
|
||||
conditional_t<is_same_v<BComputeDataType_, ck::tf32_t>, float, BComputeDataType_>;
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
@@ -659,26 +661,27 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
: false;
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeDataType,
|
||||
MfmaSelector<AComputeDataType_,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
BComputeDataType,
|
||||
BComputeDataType_,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
AComputeDataType,
|
||||
BComputeDataType,
|
||||
AccDataType,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched>();
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
AComputeDataType,
|
||||
BComputeDataType,
|
||||
AccDataType,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched,
|
||||
AComputeDataType_,
|
||||
BComputeDataType_>();
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ template <typename ALayout,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v4,
|
||||
typename BComputeDataType = AComputeDataType_>
|
||||
typename BComputeDataType_ = AComputeDataType_>
|
||||
struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
@@ -172,7 +172,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
using AComputeDataType =
|
||||
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
|
||||
#else
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
using AComputeDataType =
|
||||
conditional_t<is_same_v<AComputeDataType_, ck::tf32_t>, float, AComputeDataType_>;
|
||||
using BComputeDataType =
|
||||
conditional_t<is_same_v<BComputeDataType_, ck::tf32_t>, float, BComputeDataType_>;
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
@@ -573,7 +576,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
// This forces m/n_block_data_idx_on_grid into SGPR.
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
|
||||
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
|
||||
|
||||
@@ -640,10 +642,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
constexpr auto is_scale_mfma = false;
|
||||
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeDataType,
|
||||
MfmaSelector<AComputeDataType_,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
BComputeDataType,
|
||||
BComputeDataType_,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
@@ -659,7 +661,9 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched>();
|
||||
LoopSched,
|
||||
AComputeDataType_,
|
||||
BComputeDataType_>();
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
|
||||
@@ -41,11 +41,11 @@ static constexpr bool scale_mfma_hw_support()
|
||||
|
||||
enum struct MfmaInstr
|
||||
{
|
||||
mfma_f32_32x32x1xf32 = 0,
|
||||
mfma_f32_16x16x1xf32,
|
||||
mfma_f32_4x4x1xf32,
|
||||
mfma_f32_32x32x2xf32,
|
||||
mfma_f32_16x16x4xf32,
|
||||
mfma_f32_32x32x1f32 = 0,
|
||||
mfma_f32_16x16x1f32,
|
||||
mfma_f32_4x4x1f32,
|
||||
mfma_f32_32x32x2f32,
|
||||
mfma_f32_16x16x4f32,
|
||||
mfma_f32_32x32x4f16,
|
||||
mfma_f32_16x16x4f16,
|
||||
mfma_f32_4x4x4f16,
|
||||
@@ -78,6 +78,8 @@ enum struct MfmaInstr
|
||||
mfma_f32_16x16x128f8f6f4,
|
||||
mfma_scale_f32_32x32x64f8f6f4,
|
||||
mfma_scale_f32_16x16x128f8f6f4,
|
||||
mfma_f32_16x16x8xf32, // tf32
|
||||
mfma_f32_32x32x4xf32,
|
||||
// gfx11
|
||||
wmma_f32_16x16x16_f16,
|
||||
wmma_f32_16x16x16_bf16,
|
||||
@@ -98,7 +100,7 @@ template <MfmaInstr instr>
|
||||
struct mfma_type;
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x1xf32>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x1f32>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
@@ -120,7 +122,7 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x1xf32>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x2xf32>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x2f32>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
@@ -142,7 +144,7 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x2xf32>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x4xf32>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x4f32>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
@@ -164,7 +166,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x4xf32>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x1xf32>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x1f32>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
@@ -187,7 +189,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x1xf32>
|
||||
|
||||
// treat 4x4x1 as a single-blk 4x64 mfma
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_4x4x1xf32>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_4x4x1f32>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
@@ -947,6 +949,70 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* num_threads_per_blk == n_per_blk
|
||||
* num_regs_per_blk * num_input_blks == m_per_blk
|
||||
* num_regs_per_blk * wave_size == m_per_blk * n_per_blk
|
||||
*
|
||||
* group_size * num_groups_per_blk == num_regs_per_blk
|
||||
*
|
||||
* num_regs_per_blk is output(CD) register size which is determined by the instruction.
|
||||
* k_per_blk(K1PerXdlops) is input(AB) register size which is determined by the instruction.
|
||||
* group_size is corresponding to CD rows mapping. see: GetBeginOfThreadBlk()
|
||||
*
|
||||
* is_k_reduction = (k_per_blk == KPerXdlops) ? false: true.
|
||||
*
|
||||
* if (is_k_reduction){
|
||||
* num_output_blks == 1;
|
||||
* } else {
|
||||
* num_input_blks == num_output_blks;
|
||||
* }
|
||||
*/
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x8xf32>
|
||||
{
|
||||
static constexpr index_t wave_size = 64; // fixed
|
||||
static constexpr index_t m_per_blk = 16; // from the instruction
|
||||
static constexpr index_t n_per_blk = 16; // from the instruction
|
||||
static constexpr index_t num_threads_per_blk = n_per_blk; // 16
|
||||
static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 4
|
||||
static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 4
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t k_per_blk = 2; // k_per_blk(K1PerXdlops) should be 2.
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
// AB register size : 2, register size: 4
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x8xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x4xf32>
|
||||
{
|
||||
static constexpr index_t wave_size = 64; // fixed
|
||||
static constexpr index_t m_per_blk = 32; // from the instruction
|
||||
static constexpr index_t n_per_blk = 32; // from the instruction
|
||||
static constexpr index_t num_threads_per_blk = n_per_blk; // 32
|
||||
static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 16
|
||||
static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 2
|
||||
static constexpr index_t group_size = 4; // corresponding to CD rows mapping
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t k_per_blk = 2;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
// AB register size: 2, CD register size: 16
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x4xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
// gfx11
|
||||
struct mfma_type_gfx11_base
|
||||
{
|
||||
@@ -1116,6 +1182,20 @@ struct mfma_type<MfmaInstr::wmma_unsupport_16x16_gfx12> : public mfma_type_gfx12
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @class MfmaSelector
|
||||
* @brief Selects the appropriate MFMA instruction type and configuration for given data types
|
||||
* and tile sizes on AMD GPUs.
|
||||
*
|
||||
* @tparam base_type The base data type for the matrix operation (e.g., float, half_t).
|
||||
* @tparam MPerXdlops The number of rows per XDLops tile.
|
||||
* @tparam NPerXdlops The number of columns per XDLops tile.
|
||||
* @tparam additional_type (Optional) Additional data type for mixed-precision or special cases.
|
||||
* Defaults to base_type.
|
||||
* @tparam is_single_rate_mfma (Optional) Whether to use single-rate MFMA instructions.
|
||||
* Defaults to false.
|
||||
* @tparam is_scale_mfma (Optional) Whether to use scale MFMA instructions. Defaults to false.
|
||||
*/
|
||||
template <typename base_type,
|
||||
index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
@@ -1147,37 +1227,37 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<float, 64, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x1xf32;
|
||||
return MfmaInstr::mfma_f32_32x32x1f32;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<float, 32, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x1xf32;
|
||||
return MfmaInstr::mfma_f32_32x32x1f32;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<float, 16, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x1xf32;
|
||||
return MfmaInstr::mfma_f32_16x16x1f32;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<float, 8, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_4x4x1xf32;
|
||||
return MfmaInstr::mfma_f32_4x4x1f32;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<float, 4, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_4x4x1xf32;
|
||||
return MfmaInstr::mfma_f32_4x4x1f32;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<float, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x2xf32;
|
||||
return MfmaInstr::mfma_f32_32x32x2f32;
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -1188,10 +1268,22 @@ struct MfmaSelector
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x4xf32;
|
||||
return MfmaInstr::mfma_f32_16x16x4f32;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<tf32_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x4xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<tf32_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x8xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<half_t, 64, 64>()
|
||||
{
|
||||
@@ -1896,7 +1988,7 @@ struct XdlopsGemm
|
||||
|
||||
__device__ __host__ static constexpr index_t GetRegSizePerXdlops()
|
||||
{
|
||||
return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
|
||||
return mfma_instr.num_regs_per_blk;
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; }
|
||||
@@ -1906,12 +1998,12 @@ struct XdlopsGemm
|
||||
{
|
||||
static_assert(
|
||||
is_same<base_type, double>::value || is_same<base_type, float>::value ||
|
||||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
|
||||
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value ||
|
||||
is_same<base_type, bf8_t>::value ||
|
||||
is_same<base_type, tf32_t>::value || is_same<base_type, half_t>::value ||
|
||||
is_same<base_type, bhalf_t>::value || is_same<base_type, int8_t>::value ||
|
||||
is_same<base_type, f8_t>::value || is_same<base_type, bf8_t>::value ||
|
||||
(is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value) ||
|
||||
(is_same<base_type, bf8_t>::value && is_same<additional_type, f8_t>::value),
|
||||
"base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
|
||||
"base_type must be double, float, tf32_t, half, bfloat16, int8_t, f8_t or bf8_t!");
|
||||
|
||||
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
if constexpr(!TransposeC)
|
||||
|
||||
Reference in New Issue
Block a user