mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_Tile] Enable PreshuffleB for 2d block scale Gemm (#3298)
* formatted * formatted * formatting * formatting * formatting * [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * enable prefill shapes * [CK TILE GEMM] Refactor block_scale_gemm examples - Add support for rowcol and tensor GEMM operations * [CK TILE GEMM] Refactor block_scale_gemm examples - Update README * adding preshuffle quant as new parameter and its associated new files * remove debugging statements * adding test * enable preshuffle quant with permuteN * updating readme and correcponding gemmconfigs * updating cmake file * fixing CI failures for grouped quant gemm * debugging permuteN * debugging * debugging PermuteN * initial commit * resolving merge conflicts * adding test cases * fixing bq tensor calculation --------- Co-authored-by: Cong Ma <congma13@amd.com> Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -14,6 +14,7 @@ auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
|
||||
}
|
||||
int m_ = t->get_lengths()[0];
|
||||
int aqk_ = t->get_lengths()[1];
|
||||
|
||||
if(aqk_ % block_aq_k != 0)
|
||||
{
|
||||
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
|
||||
@@ -110,7 +111,7 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto bq_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
|
||||
@@ -118,8 +119,11 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
int bqk_ = t.get_lengths()[0];
|
||||
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
|
||||
|
||||
ck_tile::HostTensor<T> t_view(
|
||||
{n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_});
|
||||
ck_tile::HostTensor<T> t_view({n_ / (GemmConfig::N_Tile / group_n),
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::N_Warp_Tile / group_n,
|
||||
NRepeat,
|
||||
bqk_});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
|
||||
}
|
||||
|
||||
@@ -28,7 +28,6 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
|
||||
static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
@@ -205,7 +204,17 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale;
|
||||
index_t reg_offset = [&]() {
|
||||
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
|
||||
{
|
||||
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
|
||||
kQScale;
|
||||
}
|
||||
else
|
||||
{
|
||||
return nIter * KPerBlockBQ + kQScale;
|
||||
}
|
||||
}();
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float scale_reg_f = cvt_scale_to_fp32(scale_reg);
|
||||
|
||||
|
||||
@@ -747,7 +747,6 @@ struct QuantGemmKernel
|
||||
(splitk_batch_offset.splitted_k /
|
||||
GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
|
||||
Reference in New Issue
Block a user