diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp index 6a1a28884a..ef46adff6c 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp @@ -58,4 +58,4 @@ using KernelTypes = ::testing::Types< TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant, KernelTypes); -#include "test_grouped_gemm_quant_ut_cases.inc" +#include "test_grouped_gemm_quant_ut_cases.inc" \ No newline at end of file diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_abquant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_abquant.cpp deleted file mode 100644 index 1a2d8d0b35..0000000000 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_abquant.cpp +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include - -#include "gtest/gtest.h" - -#include "ck_tile/host.hpp" -#include "test_grouped_gemm_util_quant.hpp" - -using F16 = ck_tile::half_t; -using F32 = float; -using FP8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using Row = ck_tile::tensor_layout::gemm::RowMajor; -using Col = ck_tile::tensor_layout::gemm::ColumnMajor; -using True = ck_tile::bool_constant; -using False = ck_tile::bool_constant; -using ABQuant = std::integral_constant; - -// clang-format off -using KernelTypes_ABQuant = ::testing::Types< - // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, True, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, True, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, True, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, ABQuant, False, True, False>, - - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, False, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, False, False> - >; -// clang-format on - -TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_ABQuant, KernelTypes_ABQuant); - -#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_ABQuant -#include "test_grouped_gemm_quant_ut_cases.inc" -#undef TEST_CLASS_NAME - diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp index 2a6bbdc634..07c9b5a0f4 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp @@ -494,16 +494,6 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test "K must be divisible by QuantGroupSize::kK for BQuantGrouped mode"); } } - else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped) - { - AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / QuantGroupSize - BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / QuantGroupSize - if(K % QuantGroupSize::kK != 0) - { - throw std::runtime_error( - "K must be divisible by QuantGroupSize::kK for ABQuantGrouped mode"); - } - } stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(ALayout{})); stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(BLayout{})); @@ -532,13 +522,6 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test stride_BQs[i] = ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(BQLayout())); } - else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped) - { - stride_AQs[i] = - ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(AQLayout())); - stride_BQs[i] = - ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(BQLayout())); - } a_m_k_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(ALayout{})))); @@ -582,15 +565,6 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test ck_tile::HostTensor(ck_tile::host_tensor_descriptor( BQK, N, stride_BQs[i], is_row_major(BQLayout())))); } - else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped) - { - aq_tensors.push_back( - ck_tile::HostTensor(ck_tile::host_tensor_descriptor( - M, AQK, stride_AQs[i], is_row_major(AQLayout{})))); - bq_tensors.push_back( - ck_tile::HostTensor(ck_tile::host_tensor_descriptor( - BQK, N, stride_BQs[i], is_row_major(BQLayout())))); - } std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc @@ -776,18 +750,6 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test false>( a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); } - else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped) - { - ck_tile::reference_gemm_abquant( - a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref); - } const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); @@ -819,7 +781,4 @@ template using TestCkTileGroupedGemmQuant_AQuant = TestCkTileGroupedGemmQuant; template -using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant; - -template -using TestCkTileGroupedGemmQuant_ABQuant = TestCkTileGroupedGemmQuant; +using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant; \ No newline at end of file