mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK TILE] GEMM with packed i4 (#1885)
* [CK TILE] GEMM with packed i4 * Fixes * fixes * fixes * fixes
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -20,12 +21,13 @@ struct BlockUniversalGemmAsBsCr
|
||||
template <typename PipelineProblem_, typename GemmPolicy_>
|
||||
struct GemmTraits_
|
||||
{
|
||||
using Problem = remove_cvref_t<PipelineProblem_>;
|
||||
using Policy = remove_cvref_t<GemmPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using Problem = remove_cvref_t<PipelineProblem_>;
|
||||
using Policy = remove_cvref_t<GemmPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
@@ -71,10 +73,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
using BWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
|
||||
typename WarpGemm::BWarpDstrEncoding{}))>;
|
||||
|
||||
using AWarpTile =
|
||||
remove_cvref_t<decltype(make_static_distributed_tensor<ADataType>(AWarpTileDistr{}))>;
|
||||
using BWarpTile =
|
||||
remove_cvref_t<decltype(make_static_distributed_tensor<BDataType>(BWarpTileDistr{}))>;
|
||||
using AWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
|
||||
AWarpTileDistr{}))>;
|
||||
using BWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
|
||||
BWarpTileDistr{}))>;
|
||||
|
||||
// TODO: Should we have two policies? Interwave & Intrawave ??
|
||||
static constexpr index_t InterWaveSchedulingMacClusters = 1;
|
||||
@@ -90,9 +92,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
public:
|
||||
using Traits = GemmTraits_<Problem_, Policy_>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Traits::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Traits::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
using ADataType = remove_cvref_t<typename Traits::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Traits::BDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
@@ -105,10 +108,34 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
static constexpr auto Scheduler = Traits::Scheduler;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
|
||||
private:
|
||||
template <typename WarpWindow, typename WarpTile>
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(const WarpWindow& warp_window,
|
||||
WarpTile& warp_tile)
|
||||
{
|
||||
constexpr index_t UnaryOpSize = 8;
|
||||
const element_wise::PassThroughPack8 elementwise_op{};
|
||||
constexpr index_t thread_buffer_size =
|
||||
Traits::AWarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
|
||||
static_assert(Traits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
|
||||
|
||||
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
|
||||
});
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
|
||||
struct BlockGemmImpl
|
||||
{
|
||||
@@ -208,6 +235,8 @@ struct BlockUniversalGemmAsBsCr
|
||||
});
|
||||
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
@@ -217,10 +246,26 @@ struct BlockUniversalGemmAsBsCr
|
||||
// hot loop:
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
|
||||
AWarpTensor a_warp_tile;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_windows(mIter)(kIter), a_warp_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
|
||||
}
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
|
||||
BWarpTensor b_warp_tile;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_windows(nIter)(kIter), b_warp_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
|
||||
}
|
||||
|
||||
// read C warp tensor from C block tensor-
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -342,11 +387,27 @@ struct BlockUniversalGemmAsBsCr
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_windows(mIter)(kIter),
|
||||
a_warp_tiles_(mIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
|
||||
}
|
||||
});
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_windows(nIter)(kIter),
|
||||
b_warp_tiles_(nIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -504,12 +565,27 @@ struct BlockUniversalGemmAsBsCr
|
||||
// TODO check if a_warp_tiles has same desc as a_warp_window
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_windows(mIter)(kIter),
|
||||
a_warp_tiles_(mIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
|
||||
}
|
||||
});
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_windows(nIter)(kIter),
|
||||
b_warp_tiles_(nIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user