mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
[CK_Tile] Support for preshuffle weight(B) quant tensor for block scale gemm (#3165)
* formatted
* formatted
* formatting
* formatting
* formatting
* [CK TILE GEMM] Refactor block_scale_gemm examples
- Split cpp file to reduce building time
- Support multiple GemmConfig
* [CK TILE GEMM] Refactor block_scale_gemm examples
- Update Readme
* enable prefill shapes
* [CK TILE GEMM] Refactor block_scale_gemm examples
- Add support for rowcol and tensor GEMM operations
* [CK TILE GEMM] Refactor block_scale_gemm examples
- Update README
* adding preshuffle quant as new parameter and its associated new files
* remove debugging statements
* adding test
* enable preshuffle quant with permuteN
* updating readme and correcponding gemmconfigs
* updating cmake file
* fixing CI failures for grouped quant gemm
* addressing review comments
* fixing CI issue
* addressing reveiw comments
* formatting
* formatting
* fixing aquant operator overlaoding
* formatting
---------
Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
[ROCm/composable_kernel commit: 8111572785]
This commit is contained in:
@@ -20,6 +20,49 @@ auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
|
||||
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
|
||||
{
|
||||
const auto& lengths = t->get_lengths();
|
||||
const size_t rank = lengths.size();
|
||||
|
||||
// Validate block_bq_k divisibility based on rank
|
||||
int bqk_dim = (rank == 5) ? lengths[4] : (rank == 2) ? lengths[0] : -1;
|
||||
|
||||
if(bqk_dim < 0)
|
||||
{
|
||||
throw std::runtime_error("shuffle_bq expects either rank-2 or rank-5 tensor, got rank " +
|
||||
std::to_string(rank));
|
||||
}
|
||||
|
||||
if(bqk_dim % block_bq_k != 0)
|
||||
{
|
||||
throw std::runtime_error("shuffle_bq needs bqk dimension to be a multiple of block_bq_k.");
|
||||
}
|
||||
|
||||
// For TilePermuteN
|
||||
if(rank == 5)
|
||||
{
|
||||
// Handle 5D tensor: [n, nrepeat, nwarp, n_warp_tile, bqk]
|
||||
ck_tile::HostTensor<T> t_view({static_cast<int>(lengths[0]),
|
||||
static_cast<int>(lengths[1]),
|
||||
static_cast<int>(lengths[2]),
|
||||
static_cast<int>(lengths[3]),
|
||||
bqk_dim / block_bq_k,
|
||||
block_bq_k});
|
||||
std::copy(t->begin(), t->end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {4, 0, 1, 2, 3, 5});
|
||||
}
|
||||
else // rank == 2
|
||||
{
|
||||
// Handle 2D tensor: [bqk, n]
|
||||
int n_ = lengths[1];
|
||||
ck_tile::HostTensor<T> t_view({n_, bqk_dim / block_bq_k, block_bq_k});
|
||||
std::copy(t->begin(), t->end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
@@ -64,7 +107,7 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_bq_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
auto bq_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user