Refactor type conversions out of MakeBLdsBlockDescriptor, WIP!

This commit is contained in:
Sami Aario
2025-12-18 09:14:11 +00:00
parent 1b610f4aaf
commit 7fef648bca

View File

@@ -302,15 +302,12 @@ struct UniversalGemmBasePolicy
* @tparam Problem Gemm pipeline problem.
* @return B tensor LDS block descriptor.
*/
template <typename Problem>
template <typename Problem,
typename OverrideBDataType = remove_cvref_t<typename Problem::BDataType>>
CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType =
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = OverrideBDataType;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;