mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[ck][gfx12] support contraction on gfx12 (#3421)
* support contraction on gfx12 * increase tolerance for gfx11 in example contraction: the precision of gfx11 wmma is lower than that of the other targets.
This commit is contained in:
@@ -233,7 +233,20 @@ int run_contraction_bilinear_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result,
|
||||
e_ms_ns_host_result,
|
||||
"Error: Incorrect results!",
|
||||
1e-4,
|
||||
1e-4)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -216,7 +216,20 @@ int run_contraction_scale_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result,
|
||||
e_ms_ns_host_result,
|
||||
"Error: Incorrect results!",
|
||||
1e-4,
|
||||
1e-4)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -72,7 +72,12 @@ inline bool is_xdl_supported()
|
||||
is_gfx12_supported() || is_gfx11_supported();
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, index_t MPerXDL, index_t NPerXDL>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
index_t MPerXDL64,
|
||||
index_t NPerXDL64,
|
||||
index_t MPerXDL32 = MPerXDL64,
|
||||
index_t NPerXDL32 = NPerXDL64>
|
||||
inline bool is_xdl_wmma_supported()
|
||||
{
|
||||
if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
@@ -82,7 +87,7 @@ inline bool is_xdl_wmma_supported()
|
||||
}
|
||||
else if(is_gfx12_supported() || is_gfx11_supported())
|
||||
{
|
||||
if constexpr((MPerXDL != 16) || (NPerXDL != 16))
|
||||
if constexpr((MPerXDL32 != 16) || (NPerXDL32 != 16))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#endif
|
||||
#endif
|
||||
#include "ck/utility/get_id.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -96,6 +97,57 @@ static constexpr auto GetNXdlPerWave2()
|
||||
IsWave64>(); \
|
||||
}
|
||||
|
||||
template <index_t BlockSize_,
|
||||
index_t MPerBlock_,
|
||||
index_t NPerBlock_,
|
||||
index_t MPerXDL_,
|
||||
index_t NPerXDL_,
|
||||
index_t MXdlPerWave_,
|
||||
index_t CShuffleMXdlPerWavePerShuffle_,
|
||||
index_t CShuffleNXdlPerWavePerShuffle_,
|
||||
bool IsWave64>
|
||||
static constexpr auto GetWarpTileConfig()
|
||||
{
|
||||
constexpr auto MXdlPerWave64 = MXdlPerWave_;
|
||||
constexpr auto MXdlPerWave32 = MXdlPerWave_ * MPerXDL_ / 16;
|
||||
constexpr auto CShuffleMXdlPerWavePerShuffle32 = CShuffleMXdlPerWavePerShuffle_ * MPerXDL_ / 16;
|
||||
|
||||
constexpr auto NXdlPerWave =
|
||||
IsWave64
|
||||
? GetNXdlPerWave2<BlockSize_,
|
||||
MPerBlock_,
|
||||
NPerBlock_,
|
||||
MPerXDL_,
|
||||
NPerXDL_,
|
||||
MXdlPerWave_,
|
||||
true>()
|
||||
: GetNXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
|
||||
|
||||
if constexpr(IsWave64 == false && NXdlPerWave != 0)
|
||||
{
|
||||
constexpr auto CShuffleNXdlPerWavePerShuffle32 =
|
||||
NXdlPerWave >= CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
|
||||
? CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
|
||||
: CShuffleNXdlPerWavePerShuffle_;
|
||||
static_assert(CShuffleNXdlPerWavePerShuffle32 > 0);
|
||||
return Sequence<16,
|
||||
16,
|
||||
MXdlPerWave32,
|
||||
NXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle32,
|
||||
CShuffleNXdlPerWavePerShuffle32>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<MPerXDL_,
|
||||
NPerXDL_,
|
||||
MXdlPerWave64,
|
||||
NXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle_,
|
||||
CShuffleNXdlPerWavePerShuffle_>{};
|
||||
}
|
||||
}
|
||||
|
||||
#define INVOKER_RUN_IMPL \
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
|
||||
{ \
|
||||
|
||||
@@ -166,11 +166,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
{
|
||||
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
|
||||
|
||||
GET_NXDL_PER_WAVE_IMPL
|
||||
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
|
||||
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static constexpr auto WarpTileConfig64 = GetWarpTileConfig<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
true>();
|
||||
static constexpr auto WarpTileConfig32 = GetWarpTileConfig<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
false>();
|
||||
static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3);
|
||||
static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3);
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -321,7 +337,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
|
||||
|
||||
// GridwiseGemm
|
||||
template <index_t NXdlPerWave_>
|
||||
template <typename WarpTileConfig>
|
||||
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
BDataType,
|
||||
@@ -340,10 +356,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave_,
|
||||
WarpTileConfig::At(0),
|
||||
WarpTileConfig::At(1),
|
||||
WarpTileConfig::At(2),
|
||||
WarpTileConfig::At(3),
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -360,13 +376,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
WarpTileConfig::At(4),
|
||||
WarpTileConfig::At(5),
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<decltype(WarpTileConfig64)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<decltype(WarpTileConfig32)>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
@@ -588,7 +604,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_wmma_supported<ComputeDataType, ComputeDataType, MPerXDL, NPerXDL>())
|
||||
if(!ck::is_xdl_wmma_supported<ComputeDataType,
|
||||
ComputeDataType,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
WarpTileConfig32.At(0),
|
||||
WarpTileConfig32.At(1)>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -783,6 +804,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< ABlockTransferSrcVectorDim << ", "
|
||||
|
||||
Reference in New Issue
Block a user