Merge branch 'develop' of github.com:ROCm/composable_kernel into ck_moe_bs_splitk_pr

This commit is contained in:
yadaish
2025-12-16 04:40:20 +00:00
65 changed files with 3693 additions and 1067 deletions

View File

@@ -72,7 +72,12 @@ inline bool is_xdl_supported()
is_gfx12_supported() || is_gfx11_supported();
}
template <typename ADataType, typename BDataType, index_t MPerXDL, index_t NPerXDL>
template <typename ADataType,
typename BDataType,
index_t MPerXDL64,
index_t NPerXDL64,
index_t MPerXDL32 = MPerXDL64,
index_t NPerXDL32 = NPerXDL64>
inline bool is_xdl_wmma_supported()
{
if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
@@ -82,7 +87,7 @@ inline bool is_xdl_wmma_supported()
}
else if(is_gfx12_supported() || is_gfx11_supported())
{
if constexpr((MPerXDL != 16) || (NPerXDL != 16))
if constexpr((MPerXDL32 != 16) || (NPerXDL32 != 16))
{
return false;
}

View File

@@ -17,6 +17,7 @@
#endif
#endif
#include "ck/utility/get_id.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
@@ -96,6 +97,57 @@ static constexpr auto GetNXdlPerWave2()
IsWave64>(); \
}
template <index_t BlockSize_,
index_t MPerBlock_,
index_t NPerBlock_,
index_t MPerXDL_,
index_t NPerXDL_,
index_t MXdlPerWave_,
index_t CShuffleMXdlPerWavePerShuffle_,
index_t CShuffleNXdlPerWavePerShuffle_,
bool IsWave64>
static constexpr auto GetWarpTileConfig()
{
constexpr auto MXdlPerWave64 = MXdlPerWave_;
constexpr auto MXdlPerWave32 = MXdlPerWave_ * MPerXDL_ / 16;
constexpr auto CShuffleMXdlPerWavePerShuffle32 = CShuffleMXdlPerWavePerShuffle_ * MPerXDL_ / 16;
constexpr auto NXdlPerWave =
IsWave64
? GetNXdlPerWave2<BlockSize_,
MPerBlock_,
NPerBlock_,
MPerXDL_,
NPerXDL_,
MXdlPerWave_,
true>()
: GetNXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
if constexpr(IsWave64 == false && NXdlPerWave != 0)
{
constexpr auto CShuffleNXdlPerWavePerShuffle32 =
NXdlPerWave >= CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
? CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
: CShuffleNXdlPerWavePerShuffle_;
static_assert(CShuffleNXdlPerWavePerShuffle32 > 0);
return Sequence<16,
16,
MXdlPerWave32,
NXdlPerWave,
CShuffleMXdlPerWavePerShuffle32,
CShuffleNXdlPerWavePerShuffle32>{};
}
else
{
return Sequence<MPerXDL_,
NPerXDL_,
MXdlPerWave64,
NXdlPerWave,
CShuffleMXdlPerWavePerShuffle_,
CShuffleNXdlPerWavePerShuffle_>{};
}
}
#define INVOKER_RUN_IMPL \
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
{ \

View File

@@ -166,11 +166,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
{
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
GET_NXDL_PER_WAVE_IMPL
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto WarpTileConfig64 = GetWarpTileConfig<BlockSize,
MPerBlock,
NPerBlock,
MPerXDL,
NPerXDL,
MXdlPerWave,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
true>();
static constexpr auto WarpTileConfig32 = GetWarpTileConfig<BlockSize,
MPerBlock,
NPerBlock,
MPerXDL,
NPerXDL,
MXdlPerWave,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
false>();
static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3);
static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3);
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -321,7 +337,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
// GridwiseGemm
template <index_t NXdlPerWave_>
template <typename WarpTileConfig>
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
BDataType,
@@ -340,10 +356,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave_,
WarpTileConfig::At(0),
WarpTileConfig::At(1),
WarpTileConfig::At(2),
WarpTileConfig::At(3),
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
@@ -360,13 +376,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
WarpTileConfig::At(4),
WarpTileConfig::At(5),
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
using GridwiseGemm64 = GridwiseGemmBase<decltype(WarpTileConfig64)>;
using GridwiseGemm32 = GridwiseGemmBase<decltype(WarpTileConfig32)>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 =
@@ -588,7 +604,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_wmma_supported<ComputeDataType, ComputeDataType, MPerXDL, NPerXDL>())
if(!ck::is_xdl_wmma_supported<ComputeDataType,
ComputeDataType,
MPerXDL,
NPerXDL,
WarpTileConfig32.At(0),
WarpTileConfig32.At(1)>())
{
return false;
}
@@ -783,6 +804,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< AK1 << ", "
<< BK1 << ", "
<< ABlockTransferSrcVectorDim << ", "

View File

@@ -620,7 +620,44 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
bool isWave64 = get_warp_size() == 64;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& a = arg.gemm_kernel_args_[i].karg_;
const auto& a = arg.gemm_kernel_args_[i].karg_;
// Validate stride requirements for SplitK (k_batch > 1)
// TODO: Enable splitK
if(a.k_batch > 1)
{
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
if(a.StrideC != a.N)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " SplitK (k_batch=" << a.k_batch
<< ") requires contiguous output stride."
<< " For RowMajor layout: StrideC must equal N."
<< " Got StrideC=" << a.StrideC << ", N=" << a.N << std::endl;
}
return false;
}
}
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
if(a.StrideC != a.M)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " SplitK (k_batch=" << a.k_batch
<< ") requires contiguous output stride."
<< " For ColumnMajor layout: StrideC must equal M."
<< " Got StrideC=" << a.StrideC << ", M=" << a.M << std::endl;
}
return false;
}
}
}
bool group_arg_valid = false;
if(isWave64)
{

View File

@@ -366,6 +366,26 @@ struct amdgcn_compiler_target_state
#else
static constexpr bool CK_TILE_ARCH_GFX1010 = false;
#endif
#if defined(__gfx1011__)
static constexpr bool CK_TILE_ARCH_GFX1011 = true;
#else
static constexpr bool CK_TILE_ARCH_GFX1011 = false;
#endif
#if defined(__gfx1012__)
static constexpr bool CK_TILE_ARCH_GFX1012 = true;
#else
static constexpr bool CK_TILE_ARCH_GFX1012 = false;
#endif
#if defined(__gfx1013__)
static constexpr bool CK_TILE_ARCH_GFX1013 = true;
#else
static constexpr bool CK_TILE_ARCH_GFX1013 = false;
#endif
#if defined(__gfx10_1_generic__)
static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = true;
#else
static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = false;
#endif // __gfx10_1_generic__
#if defined(__gfx1030__)
static constexpr bool CK_TILE_ARCH_GFX1030 = true;
@@ -504,6 +524,10 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1011, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1012, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1013, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX10_1_GENERIC, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \

View File

@@ -68,7 +68,7 @@ auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
@@ -78,10 +78,10 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
@@ -98,18 +98,24 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
gemmConfig.K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
return shuffle_b(t, GemmConfig{});
}
template <typename GemmConfig, typename T>
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
{
@@ -129,22 +135,22 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
}
template <typename GemmConfig, typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
@@ -161,17 +167,23 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
gemmConfig.K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}
}
template <typename GemmConfig, typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
{
return shuffle_b_permuteN(t, GemmConfig{});
}
} // namespace ck_tile

View File

@@ -35,7 +35,8 @@ template <typename AsDataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false,
index_t BlockedXDLN_PerWarp_ = 1> // The number of continuous xdl_output per warp
index_t BlockedXDLN_PerWarp_ = 1, // The number of continuous xdl_output per warp
bool DoubleSmemBuffer_ = false>
struct CShuffleEpilogueProblem
{
using AsDataType = remove_cvref_t<AsDataType_>;
@@ -59,6 +60,7 @@ struct CShuffleEpilogueProblem
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
@@ -118,6 +120,7 @@ struct CShuffleEpilogue
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
@@ -204,6 +207,26 @@ struct CShuffleEpilogue
}
return max_vector_size / sizeof(DiDataType);
}
/**
* @brief Shuffle tile configuration parameters check and aligment
*
* @details Return tuple(1, 1) if shuffle_tile values are too large for SMEM.
*/
template <index_t m_shuffle_tile, index_t n_shuffle_tile>
CK_TILE_HOST_DEVICE static constexpr auto AlignShuffleTileWithSmem()
{
constexpr index_t m_val = MPerXdl * MWave * m_shuffle_tile;
constexpr index_t n_val = NPerXdl * NWave * n_shuffle_tile;
constexpr auto shuffle_tile =
m_val * n_val * sizeof(ODataType) > get_smem_capacity() || DoubleSmemBuffer
? std::make_tuple(1, 1)
: std::make_tuple(m_shuffle_tile, n_shuffle_tile);
return shuffle_tile;
}
/**
* @brief Shuffle tile configuration parameters
*
@@ -214,20 +237,23 @@ struct CShuffleEpilogue
*/
static constexpr auto shuffle_tile_tuple = [] {
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
if constexpr(elem_per_thread >= GetVectorSizeC())
if constexpr(elem_per_thread <= GetVectorSizeC())
{
return std::make_tuple(1, 1);
}
else
{
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
constexpr index_t num_xdl_shuffles = elem_per_thread / GetVectorSizeC();
static_assert(elem_per_thread % GetVectorSizeC() == 0);
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
(kMPerBlock % num_xdl_shuffles == 0),
"kMPerBlock must be divisible by MPerXdl*MWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
return AlignShuffleTileWithSmem<min(num_xdl_shuffles,
kMPerBlock / (MPerXdl * MWave)),
1>();
}
else
{
@@ -235,7 +261,9 @@ struct CShuffleEpilogue
(kNPerBlock % num_xdl_shuffles == 0),
"kNPerBlock must be divisible by NPerXdl*NWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
return AlignShuffleTileWithSmem<1,
min(num_xdl_shuffles,
kNPerBlock / (NPerXdl * NWave))>();
}
}
}();

View File

@@ -232,7 +232,7 @@ struct BatchedGemmKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr1[GetSmemSize()];
__shared__ char smem_ptr1[GemmPipeline::GetSmemSize()];
UniversalGemmKernel::RunGemm2LDS({a_ptr},
{b_ptr},
{/*ds_ptr*/},

View File

@@ -310,7 +310,7 @@ struct GroupedGemmKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
c_ptr,

View File

@@ -1084,7 +1084,7 @@ struct UniversalGemmKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
@@ -1169,7 +1169,7 @@ struct UniversalGemmKernel
// Run the GEMM
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&

View File

@@ -43,4 +43,26 @@ struct TileGemmShape
}
};
template <typename PrecType, index_t M_Warp_Tile, bool IsFlatMM = false>
constexpr index_t get_k_warp_tile()
{
#if CK_TILE_USE_WMMA
return 16;
#else
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 16 : 32;
else
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 32 : 64;
#endif
#endif
}
} // namespace ck_tile

View File

@@ -61,6 +61,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
@@ -156,9 +157,11 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using CDataType = remove_cvref_t<typename Traits::CDataType>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using OverrideBDataType = std::conditional_t<
std::is_same_v<BDataType, pk_int4_t> &&
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>,
ADataType,
BDataType>;
using Base = BlockGemmBQuantBase<Problem_>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;

View File

@@ -1324,7 +1324,7 @@ struct QuantGemmKernel
assert(kargs.k_batch == 1);
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
RunGemm2LDS(a_ptr,
b_ptr,

View File

@@ -325,7 +325,7 @@ struct QuantGroupedGemmKernel
kQuantType == QuantType::BQuantGrouped)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
aq_ptr,

View File

@@ -33,9 +33,17 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
std::conditional_t<std::is_same_v<BDataType, pk_int4_t> &&
std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>,
ADataType,
BDataType>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
using I0 = number<0>;
@@ -50,11 +58,6 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr index_t BQPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BQDataType>>::PackedSize;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
static constexpr index_t BlockSize = Problem::kBlockSize;
@@ -184,6 +187,23 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(b_block_tile, b_dram_window);
}
template <typename BBlockTile_, typename BDramWindow, typename BDramTileWindowStep>
CK_TILE_DEVICE void
BGlobalPrefetch(BBlockTile_& b_block_tile,
BDramWindow& b_copy_dram_window,
const BDramTileWindowStep& b_dram_tile_window_step) const
{
if constexpr(!std::is_same_v<BDataType, OverrideBDataType>)
{
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
}
else
{
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
}
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
@@ -262,7 +282,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<ADataType>(BBlockTileDistr{}));
decltype(make_static_distributed_tensor<OverrideBDataType>(BBlockTileDistr{}));
using BQBlockTile =
decltype(make_static_distributed_tensor<BQDataType>(BQBlockTileDistr{}));
@@ -289,8 +309,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
// B tile gets converted to A datatype during loading
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(
bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step);
@@ -311,7 +330,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
// B datatype is converted to A datatype during loading
auto b_shuffle_tmp = make_static_distributed_tensor<ADataType>(
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
@@ -322,8 +341,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
// B tile gets converted to A datatype during loading
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
@@ -366,8 +385,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
// B tile gets converted to A datatype during loading
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2],
bq_copy_dram_window,
bq_dram_tile_window_step);

View File

@@ -1048,7 +1048,7 @@ struct GroupedConvolutionBackwardDataKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))

View File

@@ -1005,7 +1005,7 @@ struct GroupedConvolutionBackwardWeightKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&

View File

@@ -1184,7 +1184,7 @@ struct GroupedConvolutionForwardKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&