mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_Tile] Support for preshuffle weight(B) quant tensor for block scale gemm (#3165)
* formatted * formatted * formatting * formatting * formatting * [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * enable prefill shapes * [CK TILE GEMM] Refactor block_scale_gemm examples - Add support for rowcol and tensor GEMM operations * [CK TILE GEMM] Refactor block_scale_gemm examples - Update README * adding preshuffle quant as new parameter and its associated new files * remove debugging statements * adding test * enable preshuffle quant with permuteN * updating readme and correcponding gemmconfigs * updating cmake file * fixing CI failures for grouped quant gemm * addressing review comments * fixing CI issue * addressing reveiw comments * formatting * formatting * fixing aquant operator overlaoding * formatting --------- Co-authored-by: Cong Ma <congma13@amd.com> Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -20,6 +20,49 @@ auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
|
||||
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
|
||||
{
|
||||
const auto& lengths = t->get_lengths();
|
||||
const size_t rank = lengths.size();
|
||||
|
||||
// Validate block_bq_k divisibility based on rank
|
||||
int bqk_dim = (rank == 5) ? lengths[4] : (rank == 2) ? lengths[0] : -1;
|
||||
|
||||
if(bqk_dim < 0)
|
||||
{
|
||||
throw std::runtime_error("shuffle_bq expects either rank-2 or rank-5 tensor, got rank " +
|
||||
std::to_string(rank));
|
||||
}
|
||||
|
||||
if(bqk_dim % block_bq_k != 0)
|
||||
{
|
||||
throw std::runtime_error("shuffle_bq needs bqk dimension to be a multiple of block_bq_k.");
|
||||
}
|
||||
|
||||
// For TilePermuteN
|
||||
if(rank == 5)
|
||||
{
|
||||
// Handle 5D tensor: [n, nrepeat, nwarp, n_warp_tile, bqk]
|
||||
ck_tile::HostTensor<T> t_view({static_cast<int>(lengths[0]),
|
||||
static_cast<int>(lengths[1]),
|
||||
static_cast<int>(lengths[2]),
|
||||
static_cast<int>(lengths[3]),
|
||||
bqk_dim / block_bq_k,
|
||||
block_bq_k});
|
||||
std::copy(t->begin(), t->end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {4, 0, 1, 2, 3, 5});
|
||||
}
|
||||
else // rank == 2
|
||||
{
|
||||
// Handle 2D tensor: [bqk, n]
|
||||
int n_ = lengths[1];
|
||||
ck_tile::HostTensor<T> t_view({n_, bqk_dim / block_bq_k, block_bq_k});
|
||||
std::copy(t->begin(), t->end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
@@ -64,7 +107,7 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_bq_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
auto bq_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user