mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-25 07:14:37 +00:00
Merge commit '7e93eed8787afd175d3a045303096a4a98638f4b' into develop
This commit is contained in:
@@ -72,7 +72,12 @@ inline bool is_xdl_supported()
|
||||
is_gfx12_supported() || is_gfx11_supported();
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, index_t MPerXDL, index_t NPerXDL>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
index_t MPerXDL64,
|
||||
index_t NPerXDL64,
|
||||
index_t MPerXDL32 = MPerXDL64,
|
||||
index_t NPerXDL32 = NPerXDL64>
|
||||
inline bool is_xdl_wmma_supported()
|
||||
{
|
||||
if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
@@ -82,7 +87,7 @@ inline bool is_xdl_wmma_supported()
|
||||
}
|
||||
else if(is_gfx12_supported() || is_gfx11_supported())
|
||||
{
|
||||
if constexpr((MPerXDL != 16) || (NPerXDL != 16))
|
||||
if constexpr((MPerXDL32 != 16) || (NPerXDL32 != 16))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#endif
|
||||
#endif
|
||||
#include "ck/utility/get_id.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -96,6 +97,57 @@ static constexpr auto GetNXdlPerWave2()
|
||||
IsWave64>(); \
|
||||
}
|
||||
|
||||
template <index_t BlockSize_,
|
||||
index_t MPerBlock_,
|
||||
index_t NPerBlock_,
|
||||
index_t MPerXDL_,
|
||||
index_t NPerXDL_,
|
||||
index_t MXdlPerWave_,
|
||||
index_t CShuffleMXdlPerWavePerShuffle_,
|
||||
index_t CShuffleNXdlPerWavePerShuffle_,
|
||||
bool IsWave64>
|
||||
static constexpr auto GetWarpTileConfig()
|
||||
{
|
||||
constexpr auto MXdlPerWave64 = MXdlPerWave_;
|
||||
constexpr auto MXdlPerWave32 = MXdlPerWave_ * MPerXDL_ / 16;
|
||||
constexpr auto CShuffleMXdlPerWavePerShuffle32 = CShuffleMXdlPerWavePerShuffle_ * MPerXDL_ / 16;
|
||||
|
||||
constexpr auto NXdlPerWave =
|
||||
IsWave64
|
||||
? GetNXdlPerWave2<BlockSize_,
|
||||
MPerBlock_,
|
||||
NPerBlock_,
|
||||
MPerXDL_,
|
||||
NPerXDL_,
|
||||
MXdlPerWave_,
|
||||
true>()
|
||||
: GetNXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
|
||||
|
||||
if constexpr(IsWave64 == false && NXdlPerWave != 0)
|
||||
{
|
||||
constexpr auto CShuffleNXdlPerWavePerShuffle32 =
|
||||
NXdlPerWave >= CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
|
||||
? CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
|
||||
: CShuffleNXdlPerWavePerShuffle_;
|
||||
static_assert(CShuffleNXdlPerWavePerShuffle32 > 0);
|
||||
return Sequence<16,
|
||||
16,
|
||||
MXdlPerWave32,
|
||||
NXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle32,
|
||||
CShuffleNXdlPerWavePerShuffle32>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<MPerXDL_,
|
||||
NPerXDL_,
|
||||
MXdlPerWave64,
|
||||
NXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle_,
|
||||
CShuffleNXdlPerWavePerShuffle_>{};
|
||||
}
|
||||
}
|
||||
|
||||
#define INVOKER_RUN_IMPL \
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
|
||||
{ \
|
||||
|
||||
@@ -166,11 +166,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
{
|
||||
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
|
||||
|
||||
GET_NXDL_PER_WAVE_IMPL
|
||||
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
|
||||
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static constexpr auto WarpTileConfig64 = GetWarpTileConfig<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
true>();
|
||||
static constexpr auto WarpTileConfig32 = GetWarpTileConfig<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
false>();
|
||||
static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3);
|
||||
static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3);
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -321,7 +337,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
|
||||
|
||||
// GridwiseGemm
|
||||
template <index_t NXdlPerWave_>
|
||||
template <typename WarpTileConfig>
|
||||
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
BDataType,
|
||||
@@ -340,10 +356,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave_,
|
||||
WarpTileConfig::At(0),
|
||||
WarpTileConfig::At(1),
|
||||
WarpTileConfig::At(2),
|
||||
WarpTileConfig::At(3),
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -360,13 +376,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
WarpTileConfig::At(4),
|
||||
WarpTileConfig::At(5),
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<decltype(WarpTileConfig64)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<decltype(WarpTileConfig32)>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
@@ -588,7 +604,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_wmma_supported<ComputeDataType, ComputeDataType, MPerXDL, NPerXDL>())
|
||||
if(!ck::is_xdl_wmma_supported<ComputeDataType,
|
||||
ComputeDataType,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
WarpTileConfig32.At(0),
|
||||
WarpTileConfig32.At(1)>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -783,6 +804,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< ABlockTransferSrcVectorDim << ", "
|
||||
|
||||
@@ -68,7 +68,7 @@ auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
@@ -78,10 +78,10 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
@@ -98,18 +98,24 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
gemmConfig.K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
return shuffle_b(t, GemmConfig{});
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
|
||||
{
|
||||
@@ -129,22 +135,22 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
|
||||
gemmConfig.N_Warp,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
@@ -161,17 +167,23 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
|
||||
gemmConfig.N_Warp,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
gemmConfig.K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
return shuffle_b_permuteN(t, GemmConfig{});
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -43,4 +43,26 @@ struct TileGemmShape
|
||||
}
|
||||
};
|
||||
|
||||
template <typename PrecType, index_t M_Warp_Tile, bool IsFlatMM = false>
|
||||
constexpr index_t get_k_warp_tile()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return 16;
|
||||
#else
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 16 : 32;
|
||||
else
|
||||
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 32 : 64;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user