Merge commit '6b1bceca7baea62941793e562d6ff58c571d9191' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-05 18:14:37 +00:00
parent e4b2f98d0d
commit b2019db495
7 changed files with 257 additions and 36 deletions

View File

@@ -28,7 +28,6 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
@@ -205,7 +204,17 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
}
else
{
constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale;
index_t reg_offset = [&]() {
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
kQScale;
}
else
{
return nIter * KPerBlockBQ + kQScale;
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = cvt_scale_to_fp32(scale_reg);

View File

@@ -747,7 +747,6 @@ struct QuantGemmKernel
(splitk_batch_offset.splitted_k /
GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kFlatN, kFlatK),