[CK_Tile] Support for preshuffle weight(B) quant tensor for block scale gemm (#3165)

* formatted

* formatted

* formatting

* formatting

* formatting

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Split cpp file to reduce building time
- Support multiple GemmConfig

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update Readme

* enable prefill shapes

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Add support for rowcol and tensor GEMM operations

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update README

* adding preshuffle quant as new parameter and its associated new files

* remove debugging statements

* adding test

* enable preshuffle quant with permuteN

* updating readme and correcponding gemmconfigs

* updating cmake file

* fixing CI failures for grouped quant gemm

* addressing review comments

* fixing CI issue

* addressing reveiw comments

* formatting

* formatting

* fixing aquant operator overlaoding

* formatting

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Khushbu Agarwal
2025-11-24 07:48:42 -08:00
committed by GitHub
parent e857e26bf6
commit 8111572785
31 changed files with 855 additions and 247 deletions

View File

@@ -20,6 +20,49 @@ auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
template <typename T>
auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
{
const auto& lengths = t->get_lengths();
const size_t rank = lengths.size();
// Validate block_bq_k divisibility based on rank
int bqk_dim = (rank == 5) ? lengths[4] : (rank == 2) ? lengths[0] : -1;
if(bqk_dim < 0)
{
throw std::runtime_error("shuffle_bq expects either rank-2 or rank-5 tensor, got rank " +
std::to_string(rank));
}
if(bqk_dim % block_bq_k != 0)
{
throw std::runtime_error("shuffle_bq needs bqk dimension to be a multiple of block_bq_k.");
}
// For TilePermuteN
if(rank == 5)
{
// Handle 5D tensor: [n, nrepeat, nwarp, n_warp_tile, bqk]
ck_tile::HostTensor<T> t_view({static_cast<int>(lengths[0]),
static_cast<int>(lengths[1]),
static_cast<int>(lengths[2]),
static_cast<int>(lengths[3]),
bqk_dim / block_bq_k,
block_bq_k});
std::copy(t->begin(), t->end(), t_view.begin());
return ck_tile::reference_permute(t_view, {4, 0, 1, 2, 3, 5});
}
else // rank == 2
{
// Handle 2D tensor: [bqk, n]
int n_ = lengths[1];
ck_tile::HostTensor<T> t_view({n_, bqk_dim / block_bq_k, block_bq_k});
std::copy(t->begin(), t->end(), t_view.begin());
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
@@ -64,7 +107,7 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
}
template <typename GemmConfig, typename T>
auto shuffle_bq_permuteN(const ck_tile::HostTensor<T>& t)
auto bq_permuteN(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);