Merge commit 'cd606f72c1fb3a99d596ad0f79521b46152764cb' into develop

This commit is contained in:
assistant-librarian[bot]
2025-06-18 01:43:45 +00:00
parent c8aa6b376e
commit fc2e176c8d
7 changed files with 315 additions and 158 deletions

View File

@@ -71,9 +71,11 @@ struct Default2DEpilogue
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
template <typename ODramWindowTmp, typename OAccTile>
CK_TILE_DEVICE auto
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const OAccTile& o_acc_tile,
const DsDramWindows& /* unused */,
void* = nullptr)
{
// TODO: this is ugly
@@ -114,6 +116,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using DsDataType = ck_tile::tuple<>;
using DsLayout = ck_tile::tuple<>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
@@ -181,6 +185,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
static_assert(false, "Unsupported CLayout!");
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; }
};
} // namespace ck_tile

View File

@@ -42,8 +42,7 @@ struct GemmPipelineProblemBase
static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK;
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;

View File

@@ -12,7 +12,8 @@ template <bool kPadM_,
bool kPadK_,
typename ALayout_,
typename BLayout_,
typename CLayout_>
typename CLayout_,
index_t NumWaveGroups_ = 1>
struct TileGemmTraits
{
static constexpr bool kPadM = kPadM_;
@@ -28,7 +29,7 @@ struct TileGemmTraits
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
static constexpr index_t NumWaveGroups = NumWaveGroups_;
};
template <bool kPadM_,