mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
[CK_TILE]: PreshuffleB + PreshuffleBQuant for ABQuant pipeline (#4268)
## Proposed changes Implement BQuantPreshuffle option for the ABQuant PreshuffleB pipeline. ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [X] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [X] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [X] I have added inline documentation which enables the maintainers with understanding the motivation - [X] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [X] I have run `clang-format` on all changed files - [X] Any dependent changes have been merged --- 🔁 Imported from [ROCm/composable_kernel#3687](https://github.com/ROCm/composable_kernel/pull/3687) 🧑💻 Originally authored by @ErwinTerpstra --------- Co-authored-by: Erwin Terpstra <erwin.terpstra@streamhpc.com> Co-authored-by: systems-assistant[bot] <systems-assistant[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
1ac61a54c9
commit
f6bb48458d
@@ -76,6 +76,10 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant
|
||||
test_gemm_quant_abquant_preshuffle_preshuffleQuant.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_base
|
||||
test_gemm_quant_abquant_a4w4_base.cpp
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using ABQuantGrouped =
|
||||
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// 2d block sizes for BQuant
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for ABQuant tests
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
|
||||
// clang-format off
|
||||
using ABQuantPreshuffleQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantPrefill, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantPrefill, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for ABQuant
|
||||
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleQuantTypes);
|
||||
|
||||
// AQuant tests
|
||||
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -159,6 +159,11 @@ struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBP
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleBPreshuffleQuantPrefill : public GemmConfigPreshuffleBPrefill
|
||||
{
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode
|
||||
{
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
|
||||
Reference in New Issue
Block a user