mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4267 (commit 3c5d95e)
[CK_TILE] Extend support of mix precision microscaling BQuant (#4267) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes Supported types combinations using BQuant=e8m0: - A=bf16 - B=bf16,bf8,fp4 Summary: - remove usage of `pk_fp4_raw_t`: consistent with other implementations and avoid taking into account of the packed size explicitly. In general, the raw type should not be used because CK Tile internally takes care of the PackedSize, so using the raw type adds unnecessary complexity to the implementation - handle microscaling by checking for `e8m0` type for BQuant (previous implementation was inconsistent) - add support for scaling instructions in `DequantPack8` - mx pipeline: - extend existing pipeline to support different B types - add support to scale and cast before writing to LDS or after reading from LDS (this can be defined in the `Problem` by the user) - block gemm: - mx pipeline is now using block gemm BQuant - block gemm BQuant can now load from LDS and apply scale and then call block gemm universal operator. This adds new functionalities and remove code duplication - warp gemm: - add case to support 128bit ds_read/write for both A and B when A=16bit and B=8bit - add examples and tests: note that some tests for bf16/fp4 already existed but were removed during previous tests refactoring. I added them again and other relevant tests for new types combinations ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [ ] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
This commit is contained in:
committed by
assistant-librarian[bot]
parent
3af1a0aafc
commit
4c626aeaa6
@@ -190,6 +190,47 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_preshuffleQuant_prefill_2d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# BQuant microscale tests
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_microscale_rcr_1d_64
|
||||
test_gemm_quant_bquant_microscale_rcr_1d_64.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_microscale_rcr_1d_64 PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_microscale_crr_1d_64
|
||||
test_gemm_quant_bquant_microscale_crr_1d_64.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_microscale_crr_1d_64 PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_microscale_rrr_1d_64
|
||||
test_gemm_quant_bquant_microscale_rrr_1d_64.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_microscale_rrr_1d_64 PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_microscale_ccr_1d_64
|
||||
test_gemm_quant_bquant_microscale_ccr_1d_64.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_microscale_ccr_1d_64 PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_microscale_rcr_1d_128
|
||||
test_gemm_quant_bquant_microscale_rcr_1d_128.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_microscale_rcr_1d_128 PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_microscale_crr_1d_128
|
||||
test_gemm_quant_bquant_microscale_crr_1d_128.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_microscale_crr_1d_128 PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_microscale_rrr_1d_128
|
||||
test_gemm_quant_bquant_microscale_rrr_1d_128.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_microscale_rrr_1d_128 PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_microscale_ccr_1d_128
|
||||
test_gemm_quant_bquant_microscale_ccr_1d_128.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_microscale_ccr_1d_128 PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# RowColQuant tests
|
||||
add_gtest_executable(test_tile_gemm_quant_rowcol
|
||||
test_gemm_quant_rowcol.cpp
|
||||
|
||||
@@ -153,7 +153,7 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType = std::conditional_t<
|
||||
std::is_same_v<BDataType_, ck_tile::pk_fp4_raw_t>,
|
||||
std::is_same_v<BDataType_, ck_tile::pk_fp4_t>,
|
||||
ADataType_,
|
||||
std::conditional_t<sizeof(ADataType_) < sizeof(BDataType_), ADataType_, BDataType_>>;
|
||||
// Calculate thresholds
|
||||
|
||||
@@ -25,9 +25,9 @@ using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
// clang-format off
|
||||
using BQuant1D128Types = ::testing::Types<
|
||||
// 1d cases with grouping only on k axis
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using PkFP4 = ck_tile::pk_fp4_t;
|
||||
using E8M0 = ck_tile::e8m0_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for BQuant tests - 1D GroupSize 128
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuant1D128Types = ::testing::Types<
|
||||
// CCR BQ: C
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant 1D 128
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D128Types);
|
||||
|
||||
// BQuant tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using FP16 = ck_tile::fp16_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using PkFP4 = ck_tile::pk_fp4_t;
|
||||
using E8M0 = ck_tile::e8m0_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
|
||||
// Type combinations for BQuant tests - 1D GroupSize 64
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuant1D64Types = ::testing::Types<
|
||||
// CCR BQ: C
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant 1D 64
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types);
|
||||
|
||||
// BQuant tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using PkFP4 = ck_tile::pk_fp4_t;
|
||||
using E8M0 = ck_tile::e8m0_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for BQuant tests - 1D GroupSize 128
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuant1D128Types = ::testing::Types<
|
||||
// CRR BQ: C
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>,
|
||||
// CRR BQ: R
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMxFP4, GroupSize128>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant 1D 128
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D128Types);
|
||||
|
||||
// BQuant tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using PkFP4 = ck_tile::pk_fp4_t;
|
||||
using E8M0 = ck_tile::e8m0_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
|
||||
// Type combinations for BQuant tests - 1D GroupSize 64
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuant1D64Types = ::testing::Types<
|
||||
// CRR BQ: C
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
// CRR BQ: R
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMxFP4, GroupSize64>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant 1D 64
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types);
|
||||
|
||||
// BQuant tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using FP16 = ck_tile::fp16_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using PkFP4 = ck_tile::pk_fp4_t;
|
||||
using E8M0 = ck_tile::e8m0_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for BQuant tests - 1D GroupSize 128
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuant1D128Types = ::testing::Types<
|
||||
// RCR BQ: C
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>,
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>,
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>,
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, FP16, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize128>,
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, FP8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>,
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, FP8, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize128>,
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, BF8, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize128>,
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, PkFP4, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize128>,
|
||||
// RCR BQ: R
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, RowMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>,
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, RowMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant 1D 128
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D128Types);
|
||||
|
||||
// BQuant tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using FP16 = ck_tile::fp16_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using PkFP4 = ck_tile::pk_fp4_t;
|
||||
using E8M0 = ck_tile::e8m0_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
|
||||
// Type combinations for BQuant tests - 1D GroupSize 64
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuant1D64Types = ::testing::Types<
|
||||
// RCR BQ: C
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, FP16, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, FP8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, FP8, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, BF8, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, PkFP4, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
// RCR BQ: R
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, RowMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
std::tuple< RowMajor, ColumnMajor, RowMajor, RowMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant 1D 64
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types);
|
||||
|
||||
// BQuant tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using PkFP4 = ck_tile::pk_fp4_t;
|
||||
using E8M0 = ck_tile::e8m0_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for BQuant tests - 1D GroupSize 128
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuant1D128Types = ::testing::Types<
|
||||
// RRR BQ: C
|
||||
std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>,
|
||||
std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMxFP4, GroupSize128>,
|
||||
std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>,
|
||||
// RRR BQ: R
|
||||
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant 1D 128
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D128Types);
|
||||
|
||||
// BQuant tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using PkFP4 = ck_tile::pk_fp4_t;
|
||||
using E8M0 = ck_tile::e8m0_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
|
||||
// Type combinations for BQuant tests - 1D GroupSize 64
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuant1D64Types = ::testing::Types<
|
||||
// RRR BQ: C
|
||||
std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMxFP4, GroupSize64>,
|
||||
std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
// RRR BQ: R
|
||||
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant 1D 64
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types);
|
||||
|
||||
// BQuant tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -102,13 +102,24 @@ struct GemmConfigDecodeInterwave : public GemmConfigBase
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
struct GemmConfigMxFp4 : public GemmConfigBase
|
||||
struct GemmConfigMx : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
};
|
||||
|
||||
// This configuration uses K_Warp_Tile = 64 on CDNA. In this way, on gfx950 we can use
|
||||
// LDS load transpose on matrix B (FP4) because the instruction requires each
|
||||
// lane to load 16 4bits elements
|
||||
struct GemmConfigMxFP4 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
{
|
||||
static constexpr bool APreshuffleQuant = true;
|
||||
@@ -666,8 +677,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
ck_tile::index_t k_batch = 1)
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B =
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t> ? (K / 2) : K;
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_C = N;
|
||||
|
||||
// BQuant uses block/grouped quantization for B matrix
|
||||
@@ -678,24 +688,36 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
// Generate test data
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t> ? K / 2 : K,
|
||||
N,
|
||||
stride_B,
|
||||
this->is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<QDataType> bq_bqk_bqn(
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{})));
|
||||
|
||||
// Initialize data with random values
|
||||
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{125.f, 130.f}(bq_bqk_bqn);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{0.f, 1.f}(b_k_n);
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<QDataType, ck_tile::e8m0_t>)
|
||||
{
|
||||
auto gen_scales = [&](auto& scales, float range_min, float range_max) {
|
||||
// e8m0_t is basically an exponent of float32
|
||||
ck_tile::HostTensor<float> pow2(scales.get_lengths());
|
||||
ck_tile::FillUniformDistributionIntegerValue<float>{range_min, range_max}(pow2);
|
||||
scales.ForEach([&](auto& self, const auto& i) {
|
||||
self(i) = static_cast<QDataType>(std::exp2(pow2(i)));
|
||||
});
|
||||
};
|
||||
gen_scales(bq_bqk_bqn, -2, 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{-1.0f, 1.0f}(bq_bqk_bqn);
|
||||
}
|
||||
|
||||
@@ -780,14 +802,15 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Run reference BQuant implementation
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
||||
ck_tile::reference_mxfp4gemm_quant<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
|
||||
if constexpr(std::is_same_v<QDataType, ck_tile::e8m0_t>)
|
||||
ck_tile::reference_mx_gemm_bquant<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
BLayout,
|
||||
false>(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
|
||||
else
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
QDataType,
|
||||
@@ -852,8 +875,11 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto b_cast_policy_v = std::is_same_v<ADataType, BDataType>
|
||||
? ck_tile::CastPolicy::BeforeLDSWrite
|
||||
: ck_tile::CastPolicy::AfterLDSRead;
|
||||
|
||||
using PipelineProblem =
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
@@ -866,18 +892,19 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
tail_number_v,
|
||||
b_cast_policy_v>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
PreshuffleB == false,
|
||||
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<std::is_same_v<QDataType, ck_tile::e8m0_t>,
|
||||
ck_tile::MicroscaleGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_t>,
|
||||
ADataType,
|
||||
BDataType>,
|
||||
ck_tile::tuple<>,
|
||||
|
||||
Reference in New Issue
Block a user