Fix handling of n dim blocks in tile windows etc

This commit is contained in:
Sami Remes
2025-10-21 15:51:23 +00:00
parent 36b88c665c
commit bb52cd9889
5 changed files with 31 additions and 19 deletions

View File

@@ -356,15 +356,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
// here from nIter and warp id
const index_t n_idx_of_warp =
nIter * WarpGemm::kN * NWarp + get_warp_id() * WarpGemm::kN;
const index_t row_index =
n_idx_of_warp / Traits::QuantGroupSize::kN;
if(threadIdx.x == 0)
{
printf("n_idx_of_warp: %d, row_index: %d, kQScale: %d\n",
n_idx_of_warp,
row_index,
kQScale.value);
}
const index_t row_index = n_idx_of_warp / Traits::QuantGroupSize::kN;
return row_index * Traits::BQPerBlock + kQScale;
}
}();
@@ -377,6 +369,16 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
// if(threadIdx.x % 64 == 0 && blockIdx.x == 0)
// {
// printf("warp_id: %d, mIter: %d, nIter: %d, kQScale: %d, reg_offset: %d, scale_reg_f: %f\n",
// get_warp_id(),
// mIter.value,
// nIter.value,
// kQScale.value,
// reg_offset,
// scale_reg_f);
// }
static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);

View File

@@ -684,9 +684,10 @@ struct QuantGemmKernel
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(kargs.N, kargs.QK_B),
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
make_tuple(kargs.stride_BQ, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
@@ -907,9 +908,9 @@ struct QuantGemmKernel
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n, 0});
{i_n / QuantGroupSize::kN, 0});
}
else
{

View File

@@ -120,6 +120,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
@@ -258,7 +259,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
static_assert(NPerBlock == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
"Bq block window has incorrect lengths for defined BqLayout!");
@@ -396,6 +397,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
bq_copy_dram_window,
bq_dram_tile_window_step);
// if(threadIdx.x == 0 && blockIdx.x == 0)
// {
// printf("---- pipeline loop %d ----\n", i);
// }
block_gemm(
c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);

View File

@@ -222,10 +222,10 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
// sequence<1, 2>,
// sequence<0, 0>>{});
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tile_distribution_encoding<sequence<MWarps, NWarps, 64>,
tuple<sequence<YPerTile>, sequence<XPerTile>>,
tuple<sequence<0>, sequence<0>>,
tuple<sequence<0>, sequence<0>>,
tuple<sequence<0, 0>, sequence<0>>,
tuple<sequence<0, 1>, sequence<2>>,
sequence<1, 2>,
sequence<0, 0>>{});
}

View File

@@ -29,7 +29,7 @@ struct GemmConfigBase
// Default GEMM tile sizes for tests
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
@@ -399,7 +399,11 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
// Initialize data with random values
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{0.f, 1.f}(b_k_n);
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(bq_bqk_bqn);
// ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(bq_bqk_bqn);
for (size_t i = 0; i < bq_bqk_bqn.size(); ++i)
{
bq_bqk_bqn.mData[i] = static_cast<QDataType>(0.0001f + 0.0001f * static_cast<float>(i));
}
// Allocate device memory
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
@@ -441,7 +445,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
N,
K, // M, N, K
0, // QK_A (not used for BQuant)
BQK, // QK_B
BQK, // QK_B - TODO: we can remove BQK and BQN from args later?
stride_A,
stride_B,
stride_C,