mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
feat(block_scale_gemm): Support RRR-R, CRR-R and CCR-C layout for aquant quant mode (#3193)
* [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * feat(gemm_quant): add RRR and CRR layout support for aquant gemm * test(gemm_quant): add unit tests for RRR and CRR layout support for aquant gemm * fix: compilation error on gfx950 by omitting support for the gpu in example and unit tests * fix: test cases compilation failure due to PR# 2095 * fix: make condition to filter out tests for gfx950 more explicit * need to support the gfx950 * fix: add layout suppot for gfx950 * Extend pk_int4_t support for block_scale_gemm aquant CR and RR layout (#3277) * WIP: add support for pk_int4_t for aquant mode layouts RR and CR * test(block_scale_gemm): add unit tests for CRR and RRR layout when data type is int4 && aquant * fix: compile time error for gfx950 * fix: minor bug where is_a_load_tr_v() was mising * feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant (#3318) * feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant * test: add unit tests for new layout support CCRC for aquant block scale gemm * docs: update changelog with new layout support info * Update CHANGELOG.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * refactor: break test instances into multiple cpp files to reduce build time (#3319) * feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant * test: add unit tests for new layout support CCRC for aquant block scale gemm * refactor: break test instances into multiple cpp files to reduce build time * chore: rename file for better code readability * fix: merge conflict resolution * fix: remove memory pipeline because new layout is not compatible * build: resolve build errors for gfx950 by modifying is_a_load_tr() & is_b_load_tr() * refactor: address review comments * solve the conflict --------- Co-authored-by: Cong Ma <congma13@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -9,11 +9,37 @@ endif()
|
||||
list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
# Typed Test Suite for GEMM Quantization
|
||||
add_gtest_executable(test_tile_gemm_quant_typed
|
||||
test_gemm_quant_typed.cpp
|
||||
# Typed Test Suite for GEMM Quantization - split into multiple files to reduce compile time
|
||||
|
||||
# AQuant tests
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant
|
||||
test_gemm_quant_aquant.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_typed PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_tile_gemm_quant_aquant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# BQuant tests (without PreshuffleB)
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant
|
||||
test_gemm_quant_bquant.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# BQuant tests (with PreshuffleB)
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle
|
||||
test_gemm_quant_bquant_preshuffle.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# RowColQuant tests
|
||||
add_gtest_executable(test_tile_gemm_quant_rowcol
|
||||
test_gemm_quant_rowcol.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_rowcol PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# TensorQuant tests
|
||||
add_gtest_executable(test_tile_gemm_quant_tensor
|
||||
test_gemm_quant_tensor.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_tensor PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
|
||||
endif()
|
||||
|
||||
95
test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp
Normal file
95
test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp
Normal file
@@ -0,0 +1,95 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
|
||||
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
|
||||
// 2d block sizes for BQuant
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for AQuant tests
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using AQuantTypes = ::testing::Types<
|
||||
// PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
// RRR layout (RowMajor A, RowMajor B, RowMajor C with RowMajor AQ)
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
// CRR layout (ColumnMajor A, RowMajor B, RowMajor C with RowMajor AQ)
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
// CCR layout (ColumnMajor A, ColumnMajor B, RowMajor C with ColumnMajor AQ) - NEW layout support
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
// RCR layout - with the Prefill BlockTile Config.
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = false && TransposeC = true (with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = true && TransposeC = false (with RowMajor AQ - PreshuffleQuant only supports RowMajor)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = true && TransposeC = true (with RowMajor AQ - PreshuffleQuant only supports RowMajor)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for AQuant
|
||||
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes);
|
||||
|
||||
// AQuant tests
|
||||
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -29,13 +29,14 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using QDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
static constexpr auto QuantType = std::tuple_element_t<7, Tuple>::value;
|
||||
using GemmConfig = std::tuple_element_t<8, Tuple>;
|
||||
using QuantGroupSize = std::tuple_element_t<9, Tuple>;
|
||||
using AQLayout = std::tuple_element_t<3, Tuple>;
|
||||
using ADataType = std::tuple_element_t<4, Tuple>;
|
||||
using BDataType = std::tuple_element_t<5, Tuple>;
|
||||
using QDataType = std::tuple_element_t<6, Tuple>;
|
||||
using CDataType = std::tuple_element_t<7, Tuple>;
|
||||
static constexpr auto QuantType = std::tuple_element_t<8, Tuple>::value;
|
||||
using GemmConfig = std::tuple_element_t<9, Tuple>;
|
||||
using QuantGroupSize = std::tuple_element_t<10, Tuple>;
|
||||
using AccDataType = float; // accumulate always in float
|
||||
|
||||
// Get the quant-type specific data types from traits
|
||||
@@ -85,6 +86,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
// BQLayout is always ColumnMajor for BQuant
|
||||
using BQLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
@@ -94,8 +98,8 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
DoubleSmemBuffer,
|
||||
false,
|
||||
|
||||
77
test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp
Normal file
77
test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp
Normal file
@@ -0,0 +1,77 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
|
||||
// 2d block sizes for BQuant
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for BQuant tests (without PreshuffleB)
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuantTypes = ::testing::Types<
|
||||
// 1d cases with grouping only on k axis (AQLayout is always RowMajor for BQuant)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
|
||||
// 2d cases with grouping also on the n axis
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant (without PreshuffleB)
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes);
|
||||
|
||||
// BQuant tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for BQuant tests with PreshuffleB
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BPreshuffleBQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant with PreshuffleB
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleBQuantTypes);
|
||||
|
||||
// BQuant PreshuffleB tests
|
||||
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -135,6 +135,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
using typename Base::AccDataType;
|
||||
using typename Base::ADataType;
|
||||
using typename Base::ALayout;
|
||||
using typename Base::AQLayout;
|
||||
using typename Base::BDataType;
|
||||
using typename Base::BLayout;
|
||||
using typename Base::CDataType;
|
||||
@@ -152,20 +153,25 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
// AQuant-specific data generation
|
||||
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_C = N;
|
||||
const ck_tile::index_t stride_A =
|
||||
ck_tile::get_default_stride(M, K, 0, this->is_row_major(ALayout{}));
|
||||
const ck_tile::index_t stride_B =
|
||||
ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{}));
|
||||
const ck_tile::index_t stride_C =
|
||||
ck_tile::get_default_stride(M, N, 0, this->is_row_major(CLayout{}));
|
||||
|
||||
// AQuant uses grouped quantization for A matrix
|
||||
const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize::kK);
|
||||
// AQLayout is parameterized in the test tuple (can be RowMajor or ColumnMajor for AQuant)
|
||||
const ck_tile::index_t stride_AQ =
|
||||
ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(ALayout{}));
|
||||
ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(AQLayout{}));
|
||||
|
||||
// Generate test data
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
|
||||
// AQLayout is independently specified for each test case
|
||||
ck_tile::HostTensor<QDataType> aq_m_aqk(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(ALayout{})));
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(AQLayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
|
||||
|
||||
@@ -407,8 +413,9 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
|
||||
// BQ is always ColumnMajor
|
||||
ck_tile::HostTensor<QDataType> bq_bqk_bqn(
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BLayout{})));
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, ck_tile::bool_constant<false>{}));
|
||||
|
||||
// Initialize data with random values
|
||||
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
|
||||
|
||||
38
test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp
Normal file
38
test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for RowColQuant tests
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using RowColQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, RowColQuant, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, RowColQuant, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for RowColQuant
|
||||
TYPED_TEST_SUITE(TestCkTileGemmRowColQuant, RowColQuantTypes);
|
||||
|
||||
// RowColQuant tests
|
||||
TYPED_TEST(TestCkTileGemmRowColQuant, RowColQuantTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
38
test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp
Normal file
38
test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for TensorQuant tests
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using TensorQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, TensorQuant, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, TensorQuant, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for TensorQuant
|
||||
TYPED_TEST_SUITE(TestCkTileGemmTensorQuant, TensorQuantTypes);
|
||||
|
||||
// TensorQuant tests
|
||||
TYPED_TEST(TestCkTileGemmTensorQuant, TensorQuantTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -1,146 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
|
||||
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
|
||||
// 2d block sizes for BQuant
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for each quantization type
|
||||
// clang-format off
|
||||
using AQuantTypes = ::testing::Types<
|
||||
// PreshuffleQuant = false && TransposeC = false
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = false && TransposeC = false && Prefill
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = false && TransposeC = true
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = true && TransposeC = false
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = true && TransposeC = true
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
using BQuantTypes = ::testing::Types<
|
||||
// 1d cases with grouping only on k axis
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
|
||||
// 2d cases with grouping also on the n axis
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
using BPreshuffleBQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantDecode, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
using RowColQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, RowColQuant, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, RowColQuant, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
using TensorQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, TensorQuant, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, TensorQuant, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suites for each quantization type
|
||||
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes);
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes);
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleBQuantTypes);
|
||||
TYPED_TEST_SUITE(TestCkTileGemmRowColQuant, RowColQuantTypes);
|
||||
TYPED_TEST_SUITE(TestCkTileGemmTensorQuant, TensorQuantTypes);
|
||||
|
||||
#include "test_gemm_quant_ut_cases.inc"
|
||||
@@ -1,33 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
// AQuant tests
|
||||
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
|
||||
// BQuant tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
|
||||
// BQuant tests
|
||||
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
// RowColQuant tests
|
||||
TYPED_TEST(TestCkTileGemmRowColQuant, RowColQuantTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
|
||||
// TensorQuant tests
|
||||
TYPED_TEST(TestCkTileGemmTensorQuant, TensorQuantTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
Reference in New Issue
Block a user