remove commented code and enable all tests again

This commit is contained in:
Sami Remes
2025-10-22 11:09:21 +00:00
parent bb52cd9889
commit f179a8a97b
4 changed files with 38 additions and 55 deletions

View File

@@ -352,7 +352,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
else
{
// FIXME: temporarily the tile distribution replicates all block's
// scales to all threads - need to figure out the index manually
// scales to all threads -> need to calculate the index manually
// here from nIter and warp id
const index_t n_idx_of_warp =
nIter * WarpGemm::kN * NWarp + get_warp_id() * WarpGemm::kN;
@@ -369,16 +369,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
// if(threadIdx.x % 64 == 0 && blockIdx.x == 0)
// {
// printf("warp_id: %d, mIter: %d, nIter: %d, kQScale: %d, reg_offset: %d, scale_reg_f: %f\n",
// get_warp_id(),
// mIter.value,
// nIter.value,
// kQScale.value,
// reg_offset,
// scale_reg_f);
// }
static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);

View File

@@ -397,10 +397,6 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
bq_copy_dram_window,
bq_dram_tile_window_step);
// if(threadIdx.x == 0 && blockIdx.x == 0)
// {
// printf("---- pipeline loop %d ----\n", i);
// }
block_gemm(
c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);

View File

@@ -192,7 +192,7 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
{
if constexpr(YPerQ == 1)
{
// YPerQ == 1 implementation
// YPerQ == 1 implementation - each row of B has independent scale
constexpr index_t X = XPerTile;
constexpr index_t XR = 2;
constexpr index_t Y0 = NIterPerWarp;
@@ -211,18 +211,10 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
}
else
{
// YPerQ > 1 implementation
// YPerQ > 1 implementation - each group of YPerQ rows share the same scale
// TODO: do not repeat everything to all threads
// return make_static_tile_distribution(
// tile_distribution_encoding<sequence<MWarps, 2, NWarps>,
// tuple<sequence<YPerTile, 1>, sequence<XPerTile>>,
// tuple<sequence<0, 0>, sequence<0, 1>>,
// tuple<sequence<0, 2>, sequence<1, 1>>,
// sequence<1, 2>,
// sequence<0, 0>>{});
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, 64>,
tile_distribution_encoding<sequence<MWarps, NWarps, get_warp_size()>,
tuple<sequence<YPerTile>, sequence<XPerTile>>,
tuple<sequence<0, 0>, sequence<0>>,
tuple<sequence<0, 1>, sequence<2>>,

View File

@@ -33,22 +33,22 @@ using AQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
// // PreshuffleQuant = false && TransposeC = true
// std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
// PreshuffleQuant = false && TransposeC = true
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
// // PreshuffleQuant = true && TransposeC = false
// std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
// PreshuffleQuant = true && TransposeC = false
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
// // PreshuffleQuant = true && TransposeC = true
// std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
// PreshuffleQuant = true && TransposeC = true
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>
>;
// clang-format on
@@ -56,30 +56,34 @@ using AQuantTypes = ::testing::Types<
// clang-format off
using BQuantTypes = ::testing::Types<
// 1d cases with grouping only on k axis
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>
// 2d cases with grouping also on the n axis
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>
>;
// clang-format on
// clang-format off
using BPreshuffleBQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
// std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>
>;
// clang-format on