[CK TILE] GEMM with packed i4 (#1885)

* [CK TILE] GEMM with packed i4

* Fixes

* fixes

* fixes

* fixes
This commit is contained in:
Bartłomiej Kocot
2025-02-20 09:59:49 +01:00
committed by GitHub
parent 824e2c1737
commit 4d9973ec8e
32 changed files with 882 additions and 305 deletions

View File

@@ -1,11 +1,12 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
namespace ck_tile {
@@ -20,12 +21,13 @@ struct BlockUniversalGemmAsBsCr
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
@@ -71,10 +73,10 @@ struct BlockUniversalGemmAsBsCr
using BWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
typename WarpGemm::BWarpDstrEncoding{}))>;
using AWarpTile =
remove_cvref_t<decltype(make_static_distributed_tensor<ADataType>(AWarpTileDistr{}))>;
using BWarpTile =
remove_cvref_t<decltype(make_static_distributed_tensor<BDataType>(BWarpTileDistr{}))>;
using AWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
AWarpTileDistr{}))>;
using BWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
BWarpTileDistr{}))>;
// TODO: Should we have two policies? Interwave & Intrawave ??
static constexpr index_t InterWaveSchedulingMacClusters = 1;
@@ -90,9 +92,10 @@ struct BlockUniversalGemmAsBsCr
public:
using Traits = GemmTraits_<Problem_, Policy_>;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
@@ -105,10 +108,34 @@ struct BlockUniversalGemmAsBsCr
static constexpr auto Scheduler = Traits::Scheduler;
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using I0 = number<0>;
using I1 = number<1>;
private:
template <typename WarpWindow, typename WarpTile>
CK_TILE_DEVICE static void load_interleaved_pk_type(const WarpWindow& warp_window,
WarpTile& warp_tile)
{
constexpr index_t UnaryOpSize = 8;
const element_wise::PassThroughPack8 elementwise_op{};
constexpr index_t thread_buffer_size =
Traits::AWarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);
static_assert(Traits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
});
}
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
struct BlockGemmImpl
{
@@ -208,6 +235,8 @@ struct BlockUniversalGemmAsBsCr
});
using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto c_warp_y_lengths =
@@ -217,10 +246,26 @@ struct BlockUniversalGemmAsBsCr
// hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
AWarpTensor a_warp_tile;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
load_interleaved_pk_type(a_warp_windows(mIter)(kIter), a_warp_tile);
}
else
{
a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
}
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
BWarpTensor b_warp_tile;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
load_interleaved_pk_type(b_warp_windows(nIter)(kIter), b_warp_tile);
}
else
{
b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
}
// read C warp tensor from C block tensor-
CWarpTensor c_warp_tensor;
@@ -342,11 +387,27 @@ struct BlockUniversalGemmAsBsCr
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
load_interleaved_pk_type(a_warp_windows(mIter)(kIter),
a_warp_tiles_(mIter)(kIter));
}
else
{
a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
}
});
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
load_interleaved_pk_type(b_warp_windows(nIter)(kIter),
b_warp_tiles_(nIter)(kIter));
}
else
{
b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
}
});
});
}
@@ -504,12 +565,27 @@ struct BlockUniversalGemmAsBsCr
// TODO check if a_warp_tiles has same desc as a_warp_window
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
load_interleaved_pk_type(a_warp_windows(mIter)(kIter),
a_warp_tiles_(mIter)(kIter));
}
else
{
a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
}
});
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
load_interleaved_pk_type(b_warp_windows(nIter)(kIter),
b_warp_tiles_(nIter)(kIter));
}
else
{
b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
}
});
});
}