[CK_TILE] Adding support for TiledPermuteN on preshuffle Block Scale Gemm (#3019)

* Adding support for TiledPermuteN

* Adding test

* resolving remod.py

---------

Co-authored-by: root <root@banff-cyxtera-s73-2.ctr.dcgpu>

[ROCm/composable_kernel commit: 0584399571]
This commit is contained in:
Khushbu Agarwal
2025-10-24 11:06:51 -07:00
committed by GitHub
parent a1681b077e
commit eef9513fd3
8 changed files with 161 additions and 98 deletions

View File

@@ -143,7 +143,11 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
transpose_c,
ck_tile::memory_operation_enum::set>>;
ck_tile::memory_operation_enum::set,
1,
false,
1,
GemmConfig::TiledMMAPermuteN>>;
using Kernel =
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;

View File

@@ -93,6 +93,7 @@ struct GemmConfigBase
static constexpr bool PreshuffleQuant = false;
static constexpr bool PreshuffleB = false;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool TiledMMAPermuteN = false;
};
template <typename PrecType>
@@ -164,6 +165,9 @@ struct GemmConfigPreshuffleB_Bquant_decode : public GemmConfigBase
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
template <typename PrecType>
@@ -184,6 +188,9 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
template <typename ADataType_,

View File

@@ -5,40 +5,7 @@
#include <random>
#include <stdexcept>
#include "ck_tile/host/permute_pk_int4.hpp"
template <typename T>
auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
{
if(t->get_lengths().size() != 2)
{
throw std::runtime_error("Host tensor is not rank 2 tensor.");
}
int m_ = t->get_lengths()[0];
int aqk_ = t->get_lengths()[1];
if(aqk_ % block_aq_k != 0)
{
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
}
ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
std::copy(t->begin(), t->end(), t_view.begin());
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
#include "ck_tile/host/shuffle_utils.hpp"
template <typename GemmConfig,
typename TypeConfig,
@@ -390,7 +357,7 @@ int run_gemm_example_with_layouts(int argc,
if constexpr(GemmConfig::PreshuffleQuant)
{
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize);
ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize);
aq_dev_buf_ptr->ToDevice(aq_shuffle_host.data());
}
else
@@ -412,25 +379,26 @@ int run_gemm_example_with_layouts(int argc,
}
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
if constexpr(GemmConfig::PreshuffleB)
{
if constexpr(GemmConfig::TiledMMAPermuteN)
{
printf("PreshuffleB with TiledMMAPermuteN\n");
b_k_n_dev = ck_tile::shuffle_b_permuteN<GemmConfig>(b_k_n);
}
else
{
printf("PreshuffleB without TiledMMAPermuteN\n");
b_k_n_dev = ck_tile::shuffle_b<GemmConfig>(b_k_n);
}
}
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
if constexpr(GemmConfig::PreshuffleB)
{
b_k_n_dev = shuffle_b<GemmConfig>(b_k_n);
}
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
{
if constexpr(GemmConfig::PreshuffleB)
{
b_k_n_dev = shuffle_b<GemmConfig>(b_k_n);
}
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
@@ -438,7 +406,15 @@ int run_gemm_example_with_layouts(int argc,
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data());
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN)
{
printf("Preshuffle BQ with TiledMMAPermuteN \n");
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
ck_tile::shuffle_bq_permuteN<GemmConfig>(*bq_tensor_ptr);
bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data());
}
else
bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data());
}
invoke_gemm<GemmConfig,

View File

@@ -46,6 +46,7 @@
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/reference/reference_transpose.hpp"
#include "ck_tile/host/rotating_buffers.hpp"
#include "ck_tile/host/shuffle_utils.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/host/timer.hpp"

View File

@@ -0,0 +1,75 @@
#pragma once
#include <stdexcept>
namespace ck_tile {
template <typename T>
auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
{
if(t->get_lengths().size() != 2)
{
throw std::runtime_error("Host tensor is not rank 2 tensor.");
}
int m_ = t->get_lengths()[0];
int aqk_ = t->get_lengths()[1];
if(aqk_ % block_aq_k != 0)
{
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
}
ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
std::copy(t->begin(), t->end(), t_view.begin());
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
template <typename GemmConfig, typename T>
auto shuffle_bq_permuteN(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int bqk_ = t.get_lengths()[0];
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
ck_tile::HostTensor<T> t_view(
{n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
}
template <typename GemmConfig, typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}
} // namespace ck_tile

View File

@@ -55,6 +55,7 @@ class TestCkTileGemmQuantBase : public ::testing::Test
static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant;
static constexpr bool PreshuffleB = GemmConfig::PreshuffleB;
static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN;
static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
public:
@@ -132,19 +133,6 @@ class TestCkTileGemmQuantBase : public ::testing::Test
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view(
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
};
// Define generic QuantTypeTraits template (will be specialized)

View File

@@ -5,6 +5,7 @@
#include "test_gemm_quant_base.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/shuffle_utils.hpp"
struct GemmConfigBase
{
@@ -26,6 +27,7 @@ struct GemmConfigBase
static constexpr bool PreshuffleQuant = false;
static constexpr bool PreshuffleB = false;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool TiledMMAPermuteN = false;
// Default GEMM tile sizes for tests
static constexpr ck_tile::index_t M_Tile = 16;
@@ -95,6 +97,12 @@ struct GemmConfigPreshuffleBPrefill : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile = 64;
};
struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill
{
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
template <typename Tuple>
class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>
{
@@ -119,24 +127,6 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
void SetUpQuantTypeSpecific() {}
void TearDownQuantTypeSpecific() {}
template <typename T>
auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
{
if(t->get_lengths().size() != 2)
{
throw std::runtime_error("Host tensor is not rank 2 tensor.");
}
int m_ = t->get_lengths()[0];
int aqk_ = t->get_lengths()[1];
if(aqk_ % block_aq_k != 0)
{
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
}
ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
std::copy(t->begin(), t->end(), t_view.begin());
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
// AQuant-specific data generation
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
{
@@ -191,7 +181,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
if constexpr(Base::GemmConfig::PreshuffleQuant)
{
ck_tile::HostTensor<QDataType> aq_shuffle_host =
shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize);
ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize);
aq_m_aqk_dev_buf.ToDevice(aq_shuffle_host.data());
}
else
@@ -367,11 +357,13 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
using typename Base::CDataType;
using typename Base::CLayout;
using typename Base::ComputeDataType;
using typename Base::GemmConfig;
using typename Base::QDataType;
static constexpr auto QuantType = Base::QuantType;
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
static constexpr auto PreshuffleB = Base::PreshuffleB;
static constexpr auto TiledMMAPermuteN = Base::TiledMMAPermuteN;
protected:
void SetUpQuantTypeSpecific() {}
@@ -409,24 +401,35 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
// Copy to device
a_m_k_dev_buf.ToDevice(a_m_k.data());
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
if constexpr(PreshuffleB)
{
if constexpr(TiledMMAPermuteN)
{
printf("PreshuffleB with TiledMMAPermuteN\n");
b_k_n_dev = ck_tile::shuffle_b_permuteN<GemmConfig>(b_k_n);
}
else
{
printf("PreshuffleB without TiledMMAPermuteN\n");
b_k_n_dev = ck_tile::shuffle_b<GemmConfig>(b_k_n);
}
}
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
if constexpr(PreshuffleB)
{
b_k_n_dev = this->shuffle_b(b_k_n);
}
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
if constexpr(PreshuffleB && TiledMMAPermuteN)
{
printf("Preshuffle BQ with TiledMMAPermuteN \n");
ck_tile::HostTensor<QDataType> bq_shuffle_host =
ck_tile::shuffle_bq_permuteN<GemmConfig>(bq_bqk_n);
bq_bqk_n_dev_buf.ToDevice(bq_shuffle_host.data());
}
else
{
if constexpr(PreshuffleB)
{
b_k_n_dev = this->shuffle_b(b_k_n);
}
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
// Create args for kernel execution
ck_tile::QuantGemmHostArgs args{
@@ -559,7 +562,11 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
Base::N_Warp_Tile,
Base::K_Warp_Tile,
false, // transpose_c
ck_tile::memory_operation_enum::set>>;
ck_tile::memory_operation_enum::set,
1,
false,
1,
TiledMMAPermuteN>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
GemmPipeline,

View File

@@ -70,7 +70,12 @@ using BPreshuffleBQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>
>;
// clang-format off