remove test code for now

This commit is contained in:
kyle-256
2025-12-16 09:11:15 +00:00
parent e82236292b
commit 7d897cea19
3 changed files with 2 additions and 82 deletions

View File

@@ -58,4 +58,4 @@ using KernelTypes = ::testing::Types<
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant, KernelTypes);
#include "test_grouped_gemm_quant_ut_cases.inc"
#include "test_grouped_gemm_quant_ut_cases.inc"

View File

@@ -1,39 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_util_quant.hpp"
// Short aliases for the element types that appear in the kernel type list below.
using F16 = ck_tile::half_t;
using F32 = float;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
// Short aliases for the row-major / column-major GEMM tensor layout tags.
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// Compile-time boolean tag types used for the on/off columns of each configuration.
using True = ck_tile::bool_constant<true>;
using False = ck_tile::bool_constant<false>;
// Wraps the QuantType::ABQuantGrouped enumerator in a type so it can be a tuple element.
using ABQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
// The type list is kept as a hand-aligned table, so clang-format is disabled around it.
// clang-format off
// Each std::tuple is one configuration the typed test suite is instantiated with; the
// column-header comment inside the list names what every tuple position means.
// All six entries use ABQuant with PreshuffleB = False and TransposeC = False:
//   - entries 1-3: FP8 inputs, Persistent = True, with (A, B) layouts Row/Col, Col/Row, Row/Row
//   - entry 4:     BF8 inputs, Persistent = True, Row/Col
//   - entries 5-6: FP8 inputs, Persistent = False, Row/Col and Col/Row
using KernelTypes_ABQuant = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, True, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, True, False>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, True, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, ABQuant, False, True, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, False, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, False, False>
>;
// clang-format on
// Bind the ABQuant fixture to the configuration list above.
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_ABQuant, KernelTypes_ABQuant);
// TEST_CLASS_NAME is defined only for the duration of the include below.
// NOTE(review): presumably the .inc file uses it to name the fixture its test cases
// attach to; the .inc is not visible here, so confirm before relying on this.
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_ABQuant
#include "test_grouped_gemm_quant_ut_cases.inc"
#undef TEST_CLASS_NAME

View File

@@ -494,16 +494,6 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
"K must be divisible by QuantGroupSize::kK for BQuantGrouped mode");
}
}
else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped)
{
AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / QuantGroupSize
BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / QuantGroupSize
if(K % QuantGroupSize::kK != 0)
{
throw std::runtime_error(
"K must be divisible by QuantGroupSize::kK for ABQuantGrouped mode");
}
}
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(ALayout{}));
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(BLayout{}));
@@ -532,13 +522,6 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
stride_BQs[i] =
ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(BQLayout()));
}
else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped)
{
stride_AQs[i] =
ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(AQLayout()));
stride_BQs[i] =
ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(BQLayout()));
}
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(ALayout{}))));
@@ -582,15 +565,6 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
BQK, N, stride_BQs[i], is_row_major(BQLayout()))));
}
else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped)
{
aq_tensors.push_back(
ck_tile::HostTensor<AQDataType>(ck_tile::host_tensor_descriptor(
M, AQK, stride_AQs[i], is_row_major(AQLayout{}))));
bq_tensors.push_back(
ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
BQK, N, stride_BQs[i], is_row_major(BQLayout()))));
}
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
<< " b_k_n: " << b_k_n_tensors[i].mDesc
@@ -776,18 +750,6 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
false>(
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
}
else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped)
{
ck_tile::reference_gemm_abquant<ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
QuantGroupSize,
QuantGroupSize>(
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref);
}
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
@@ -819,7 +781,4 @@ template <typename Tuple>
using TestCkTileGroupedGemmQuant_AQuant = TestCkTileGroupedGemmQuant<Tuple>;
// BQuant-named alias of the shared fixture template: it forwards its single type
// argument unchanged to TestCkTileGroupedGemmQuant.
template <class KernelConfig>
using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant<KernelConfig>;
// ABQuant-named alias of the shared fixture template: it forwards its single type
// argument unchanged to TestCkTileGroupedGemmQuant.
template <class KernelConfig>
using TestCkTileGroupedGemmQuant_ABQuant = TestCkTileGroupedGemmQuant<KernelConfig>;
using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant<Tuple>;