Merge commit '47cd0d5cff77658adc1c9f184c012ec3496e8214' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-19 05:12:36 +00:00
parent 142a7e067a
commit 042cd4e556
13 changed files with 183 additions and 177 deletions

View File

@@ -29,7 +29,8 @@ TYPED_TEST(TestCkTileBatchedGemm, Basic)
{256, 256, 64, 8},
{256, 256, 64, 16}};
if(ck_tile::get_device_name() != "gfx950") {
if(ck_tile::get_device_name() != "gfx950")
{
gemmParams.emplace_back(256, 256, 128, 2);
}

View File

@@ -2,6 +2,8 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/host/permute_pk_int4.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
@@ -91,61 +93,6 @@ void permute_tensor_b(Tensor& tensor)
}
}
template <typename Tensor>
void permute_vectors_i4x4_b(Tensor& tensor)
{
const ck_tile::index_t K = tensor.get_length(0);
const ck_tile::index_t N = tensor.get_length(1);
// vector pk_i4x4 permute
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j += 8)
{
int8_t input[8];
for(int k = 0; k < 4; k++)
{
int8_t i4x2 = tensor(j + k * 2, i).data;
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// permute 01234567->20643175
{
int8_t hi = input[2];
int8_t lo = input[0];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 0, i) = i4x2;
}
{
int8_t hi = input[6];
int8_t lo = input[4];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 2, i) = i4x2;
}
{
int8_t hi = input[3];
int8_t lo = input[1];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 4, i) = i4x2;
}
{
int8_t hi = input[7];
int8_t lo = input[5];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 6, i) = i4x2;
}
}
}
}
template <typename GemmConfig,
typename ADataType,
typename BDataType,

View File

@@ -13,6 +13,7 @@ using F16 = ck_tile::half_t;
using F32 = float;
using F8 = ck_tile::fp8_t;
using BF16 = ck_tile::bf16_t;
using I4 = ck_tile::pk_int4_t;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
@@ -20,20 +21,24 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Default = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Default>;
using WeightPreshuffle =
ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::WeightPreshuffle>;
// Adding alias for the F8 parameters to facilitate skipping tests.
// This alias can be removed once test failures are fixed.
using F8Types = std::tuple<Row, Col, Row, F8, F8, F32, F16, Default, WeightPreshuffle>;
using WeightPreshuffleV1 =
ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::WeightPreshuffleV1>;
using WeightPreshuffleV2 =
ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::WeightPreshuffleV2>;
// clang-format off
using KernelTypesWeightPreshuffle = ::testing::Types<
std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffle>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffle>
#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8
, F8Types
std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffleV1>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffleV2>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffleV2>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffleV1>
#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8
,
std::tuple< Row, Col, Row, F8, F8, F32, F16, Default, WeightPreshuffleV1>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, Default, WeightPreshuffleV2>,
std::tuple< Row, Col, Row, F8, I4, F32, F16, Default, WeightPreshuffleV2>,
std::tuple< Row, Col, Row, F8, I4, F32, F16, Default, WeightPreshuffleV1>
#endif
>;

View File

@@ -20,7 +20,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle)
TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x128x128)
{
if constexpr(std::is_same_v<TypeParam, F8Types>)
if constexpr(std::is_same_v<std::tuple_element_t<3, TypeParam>, F8>)
{
GTEST_SKIP() << "Skipping this test due to failures with F8";
}
@@ -48,7 +48,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x128x4096)
TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x2048x128)
{
if constexpr(std::is_same_v<TypeParam, F8Types>)
if constexpr(std::is_same_v<std::tuple_element_t<3, TypeParam>, F8>)
{
GTEST_SKIP() << "Skipping this test due to failures with F8";
}
@@ -77,7 +77,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x2048x4096)
TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x128x128)
{
if constexpr(std::is_same_v<TypeParam, F8Types>)
if constexpr(std::is_same_v<std::tuple_element_t<3, TypeParam>, F8>)
{
GTEST_SKIP() << "Skipping this test due to failures with F8";
}
@@ -106,7 +106,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x128x4096)
TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x2048x128)
{
if constexpr(std::is_same_v<TypeParam, F8Types>)
if constexpr(std::is_same_v<std::tuple_element_t<3, TypeParam>, F8>)
{
GTEST_SKIP() << "Skipping this test due to failures with F8";
}

View File

@@ -8,6 +8,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
@@ -34,20 +35,31 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
enum struct GemmPipelineType
{
WeightPreshuffle
WeightPreshuffleV1,
WeightPreshuffleV2
};
template <GemmPipelineType PT, typename Problem>
struct GemmPipelineTypeSelector;
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::WeightPreshuffle, Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::WeightPreshuffleV1, Problem>
{
using base_pipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<Problem>;
using pipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<Problem>;
static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffle"; }
static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffleV1"; }
};
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::WeightPreshuffleV2, Problem>
{
using base_pipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
using pipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffleV2"; }
};
template <typename Datatype>
struct config
{
@@ -122,7 +134,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr bool kPadK = PadK;
constexpr bool preshuffle = Preshuffle;
constexpr bool DoubleSmemBuffer = false;
constexpr bool DoubleSmemBuffer =
(PipelineType == GemmPipelineType::WeightPreshuffleV2) ? true : false;
// TODO: For now - but this should also be a test parameter
constexpr bool TransposeC = false;
@@ -391,10 +404,19 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n);
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n);
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_shuffle_host_dev = b_shuffle_host;
ck_tile::permute_vectors_i4x4_b(b_shuffle_host_dev);
b_k_n_dev_buf.ToDevice(b_shuffle_host_dev.data());
}
else
{
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
}
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();