fix compile error

This commit is contained in:
KenSCLin
2025-12-12 13:29:11 +00:00
parent 112b5ecf6b
commit d40fb754bf

View File

@@ -21,10 +21,24 @@
template <ck_tile::QuantType QT>
struct QuantTypeTraits;
template <typename TTuple, size_t Index, typename DefaultType, typename Enable = void>
struct SafeTupleElement {
using type = DefaultType;
};
template <typename TTuple, size_t Index, typename DefaultType>
using SafeTupleElement_t = std::conditional_t<(Index < std::tuple_size_v<TTuple>),
std::tuple_element_t<Index, TTuple>,
DefaultType>;
struct SafeTupleElement<
TTuple,
Index,
DefaultType,
std::enable_if_t<(Index < std::tuple_size_v<TTuple>)>>
{
using type = std::tuple_element_t<Index, TTuple>;
};
template <typename TTuple, size_t Index, typename DefaultType>
using SafeTupleElement_t = typename SafeTupleElement<TTuple, Index, DefaultType>::type;
// Base class for common quant gemm functionality
template <typename Tuple, typename Derived>
class TestCkTileGemmQuantBase : public ::testing::Test
@@ -43,7 +57,7 @@ class TestCkTileGemmQuantBase : public ::testing::Test
using QuantGroupSize = std::tuple_element_t<10, Tuple>;
using AQuantGroupSize = QuantGroupSize;
using BQuantGroupSize = SafeTupleElement_t<Tuple, 11, QuantGroupSize>;
using BQLayout = SafeTupleElement_t<Tuple, 12, AQLayout>;
using BQLayout = SafeTupleElement_t<Tuple, 12, AQLayout>;
using AccDataType = float; // accumulate always in float
// Get the quant-type specific data types from traits
@@ -93,9 +107,6 @@ class TestCkTileGemmQuantBase : public ::testing::Test
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
// Re-use the AQLayout for BQLayout
// using BQLayout = AQLayout;
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,