mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
Fix handling of n dim blocks in tile windows etc
This commit is contained in:
@@ -356,15 +356,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
// here from nIter and warp id
|
||||
const index_t n_idx_of_warp =
|
||||
nIter * WarpGemm::kN * NWarp + get_warp_id() * WarpGemm::kN;
|
||||
const index_t row_index =
|
||||
n_idx_of_warp / Traits::QuantGroupSize::kN;
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("n_idx_of_warp: %d, row_index: %d, kQScale: %d\n",
|
||||
n_idx_of_warp,
|
||||
row_index,
|
||||
kQScale.value);
|
||||
}
|
||||
const index_t row_index = n_idx_of_warp / Traits::QuantGroupSize::kN;
|
||||
return row_index * Traits::BQPerBlock + kQScale;
|
||||
}
|
||||
}();
|
||||
@@ -377,6 +369,16 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
|
||||
// if(threadIdx.x % 64 == 0 && blockIdx.x == 0)
|
||||
// {
|
||||
// printf("warp_id: %d, mIter: %d, nIter: %d, kQScale: %d, reg_offset: %d, scale_reg_f: %f\n",
|
||||
// get_warp_id(),
|
||||
// mIter.value,
|
||||
// nIter.value,
|
||||
// kQScale.value,
|
||||
// reg_offset,
|
||||
// scale_reg_f);
|
||||
// }
|
||||
static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
|
||||
|
||||
@@ -684,9 +684,10 @@ struct QuantGemmKernel
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(kargs.N, kargs.QK_B),
|
||||
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
|
||||
make_tuple(kargs.stride_BQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
@@ -907,9 +908,9 @@ struct QuantGemmKernel
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
|
||||
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
|
||||
{i_n, 0});
|
||||
{i_n / QuantGroupSize::kN, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -120,6 +120,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
@@ -258,7 +259,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
|
||||
static_assert(NPerBlock == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
|
||||
"Bq block window has incorrect lengths for defined BqLayout!");
|
||||
|
||||
@@ -396,6 +397,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
bq_copy_dram_window,
|
||||
bq_dram_tile_window_step);
|
||||
|
||||
// if(threadIdx.x == 0 && blockIdx.x == 0)
|
||||
// {
|
||||
// printf("---- pipeline loop %d ----\n", i);
|
||||
// }
|
||||
block_gemm(
|
||||
c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
|
||||
@@ -222,10 +222,10 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
// sequence<1, 2>,
|
||||
// sequence<0, 0>>{});
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tile_distribution_encoding<sequence<MWarps, NWarps, 64>,
|
||||
tuple<sequence<YPerTile>, sequence<XPerTile>>,
|
||||
tuple<sequence<0>, sequence<0>>,
|
||||
tuple<sequence<0>, sequence<0>>,
|
||||
tuple<sequence<0, 0>, sequence<0>>,
|
||||
tuple<sequence<0, 1>, sequence<2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ struct GemmConfigBase
|
||||
|
||||
// Default GEMM tile sizes for tests
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
@@ -399,7 +399,11 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
// Initialize data with random values
|
||||
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{0.f, 1.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(bq_bqk_bqn);
|
||||
// ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(bq_bqk_bqn);
|
||||
for (size_t i = 0; i < bq_bqk_bqn.size(); ++i)
|
||||
{
|
||||
bq_bqk_bqn.mData[i] = static_cast<QDataType>(0.0001f + 0.0001f * static_cast<float>(i));
|
||||
}
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
|
||||
@@ -441,7 +445,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
N,
|
||||
K, // M, N, K
|
||||
0, // QK_A (not used for BQuant)
|
||||
BQK, // QK_B
|
||||
BQK, // QK_B - TODO: we can remove BQK and BQN from args later?
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
|
||||
Reference in New Issue
Block a user