Support b_scale: (#2350)

- extend pipeline v1 and v3
 - add instances
 - add tests
 - add example

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>

[ROCm/composable_kernel commit: b01a27ff22]
This commit is contained in:
Enrico Degregori
2025-07-25 03:49:58 +02:00
committed by GitHub
parent c721559117
commit 2d68b3f9c0
24 changed files with 3744 additions and 1660 deletions

View File

@@ -242,6 +242,7 @@ add_subdirectory(gemm_add)
add_subdirectory(gemm_layernorm)
add_subdirectory(gemm_split_k)
add_subdirectory(gemm_universal)
add_subdirectory(gemm_b_scale)
add_subdirectory(gemm_universal_streamk)
add_subdirectory(gemm_reduce)
add_subdirectory(batched_gemm)

View File

@@ -0,0 +1,9 @@
add_gtest_executable(test_gemm_b_scale_xdl test_gemm_b_scale_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_b_scale_xdl PRIVATE utility device_gemm_b_scale_instance)
endif()
add_gtest_executable(test_gemm_b_scale_wmma test_gemm_b_scale_wmma.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_b_scale_wmma PRIVATE utility device_gemm_b_scale_instance)
endif()

View File

@@ -0,0 +1,43 @@
#pragma once
TYPED_TEST(TestGemmBScale_MK_NK, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 256;
constexpr int K = 1024;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmBScale_MK_NK, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
constexpr int K = 768;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmBScale_MK_NK, Regular)
{
std::vector<int> Ms{512, 1024};
constexpr int N = 512;
constexpr int K = 1024;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}

View File

@@ -0,0 +1,97 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <sstream>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/utility/data_type.hpp"
#include "profiler/profile_gemm_b_scale_impl.hpp"
namespace ck {
namespace test {
template <typename Tuple>
class TestGemmBScale : public testing::Test
{
using Row = ck::tensor_layout::gemm::RowMajor;
using F32 = float;
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = Row;
using ADataType = std::tuple_element_t<2, Tuple>;
using BDataType = std::tuple_element_t<3, Tuple>;
using BScaleDataType = std::tuple_element_t<4, Tuple>;
using ComputeDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
public:
static constexpr ck::index_t ScaleBlockK = 128; // all instances
static constexpr bool verify_ = true;
static constexpr int init_method_ = 2;
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
std::vector<int> k_batches_;
void SetUp() override { k_batches_ = {1, 2}; }
void Run(const int M,
const int N,
const int K,
const int StrideA,
const int StrideB,
const int StrideC)
{
for(auto kb : k_batches_)
{
RunSingle(M, N, K, StrideA, StrideB, StrideC, kb);
}
}
void RunSingle(const int M,
const int N,
const int K,
const int StrideA,
const int StrideB,
const int StrideC,
int kbatch = 1,
int n_warmup = 1,
int n_iter = 10)
{
bool pass = ck::profiler::profile_gemm_b_scale_impl<ADataType,
BDataType,
BScaleDataType,
ComputeDataType,
F32,
CDataType,
ScaleBlockK,
ALayout,
BLayout,
CLayout>(verify_,
init_method_,
log_,
bench_,
M,
N,
K,
StrideA,
StrideB,
StrideC,
kbatch,
n_warmup,
n_iter);
EXPECT_TRUE(pass);
}
};
} // namespace test
} // namespace ck

View File

@@ -0,0 +1,45 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_gemm_b_scale_util.hpp"
using I4 = ck::pk_i4_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
namespace {
template <typename X, typename Y>
struct tuple_concat;
template <typename... Xs, typename... Ys>
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
{
using type = std::tuple<Xs..., Ys...>;
};
} // namespace
template <typename Tuple>
class TestGemmBScale_MK_NK
: public ck::test::TestGemmBScale<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
{
};
// clang-format off
using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, BScaleDataType, ComputeDataType, CDataType
std::tuple< F16, I4, F16, F16, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestGemmBScale_MK_NK, KernelTypes_MK_NK);
#include "test_gemm_b_scale_ut_cases.inc"

View File

@@ -0,0 +1,45 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_gemm_b_scale_util.hpp"
using I4 = ck::pk_i4_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
namespace {
template <typename X, typename Y>
struct tuple_concat;
template <typename... Xs, typename... Ys>
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
{
using type = std::tuple<Xs..., Ys...>;
};
} // namespace
template <typename Tuple>
class TestGemmBScale_MK_NK
: public ck::test::TestGemmBScale<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
{
};
// clang-format off
using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, BScaleDataType, ComputeDataType, CDataType
std::tuple< F16, I4, F16, F16, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestGemmBScale_MK_NK, KernelTypes_MK_NK);
#include "test_gemm_b_scale_ut_cases.inc"