Add gemm weight preshuffle pk_int_t support (#2858)

* Factor out the three separate copies of load_interleaved_pk_type into a common utility class

* Add preprocessing with optional cache flushing and clearing of output for k_batch > 1 to the weight preshuffle GEMM example

* Remove a duplicate function

* Add support for B tensor type pk_int4_t for the weight preshuffle GEMM, with tests included

* I4 support introduced more failing test cases that mirror the existing ones for F8

* Simplify the check for which tests to skip (they all have F8 as A tensor type)

* Add a changelog entry

* add the test for v2 wp pipeline, polish the code, add the support of int4 for v2 wp pipeline

* have a workable version for atomic add

* Revert "have a workable version for atomic add"

This reverts commit 792377a590c26cfff9c8f545d9a9e8484a7422eb.

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
SamiAario-AMD
2025-09-19 07:26:10 +03:00
committed by GitHub
parent 30ab1d6a71
commit 47cd0d5cff
13 changed files with 183 additions and 177 deletions

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
@@ -13,7 +14,9 @@ namespace ck_tile {
// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
template <typename Problem_,
typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy,
index_t UnaryOpSize_ = 8>
struct BlockUniversalGemmAsBsCr
{
private:
@@ -91,6 +94,7 @@ struct BlockUniversalGemmAsBsCr
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
@@ -179,25 +183,6 @@ struct BlockUniversalGemmAsBsCr
return b_block_dstr_encode;
}
private:
template <typename WarpWindow, typename WarpTile>
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
const WarpWindow& warp_window)
{
constexpr index_t UnaryOpSize = 8;
const element_wise::PassThroughPack8 elementwise_op{};
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
});
}
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
struct BlockGemmImpl
{
@@ -239,7 +224,7 @@ struct BlockUniversalGemmAsBsCr
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
load_interleaved_pk_type(a_warp_tile_, a_block_window);
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else
{
@@ -247,7 +232,7 @@ struct BlockUniversalGemmAsBsCr
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
load_interleaved_pk_type(b_warp_tile_, b_block_window);
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else
{
@@ -317,7 +302,7 @@ struct BlockUniversalGemmAsBsCr
{
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
load_interleaved_pk_type(a_warp_tile_, a_block_window);
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else if constexpr(ALoadTranspose)
{
@@ -329,7 +314,7 @@ struct BlockUniversalGemmAsBsCr
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
load_interleaved_pk_type(b_warp_tile_, b_block_window);
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else if constexpr(BLoadTranspose)
{
@@ -468,7 +453,7 @@ struct BlockUniversalGemmAsBsCr
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
load_interleaved_pk_type(a_warp_tile_, a_block_window);
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else if constexpr(ALoadTranspose)
{
@@ -480,7 +465,7 @@ struct BlockUniversalGemmAsBsCr
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
load_interleaved_pk_type(b_warp_tile_, b_block_window);
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else if constexpr(BLoadTranspose)
{