diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 1022572a92..169a39f2eb 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -21,10 +21,24 @@ template struct QuantTypeTraits; +template +struct SafeTupleElement { + using type = DefaultType; +}; + template -using SafeTupleElement_t = std::conditional_t<(Index < std::tuple_size_v), - std::tuple_element_t, - DefaultType>; +struct SafeTupleElement< + TTuple, + Index, + DefaultType, + std::enable_if_t<(Index < std::tuple_size_v)>> + { + using type = std::tuple_element_t; + }; + +template +using SafeTupleElement_t = typename SafeTupleElement::type; + // Base class for common quant gemm functionality template class TestCkTileGemmQuantBase : public ::testing::Test @@ -43,7 +57,7 @@ class TestCkTileGemmQuantBase : public ::testing::Test using QuantGroupSize = std::tuple_element_t<10, Tuple>; using AQuantGroupSize = QuantGroupSize; using BQuantGroupSize = SafeTupleElement_t; - using BQLayout = SafeTupleElement_t; + using BQLayout = SafeTupleElement_t; using AccDataType = float; // accumulate always in float // Get the quant-type specific data types from traits @@ -93,9 +107,6 @@ class TestCkTileGemmQuantBase : public ::testing::Test using TilePartitioner = ck_tile::GemmTile1DPartitioner; - // Re-use the AQLayout for BQLayout - // using BQLayout = AQLayout; - using CodegenGemmTraits = ck_tile::TileGemmQuantTraits