debugging

This commit is contained in:
khuagarw
2025-12-06 08:57:22 +00:00
parent 48744f2d0d
commit 3ea3ca7b36
7 changed files with 139 additions and 75 deletions

View File

@@ -677,18 +677,19 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
{
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize::kK);
printf("Preshuffle BQ tensor\n");
for(int i = 0; i < static_cast<int>(bq_shuffle_host.get_lengths()[0]); i++)
{
for(int j = 0; j < static_cast<int>(bq_shuffle_host.get_lengths()[1]); j++)
{
for(int k = 0; k < static_cast<int>(bq_shuffle_host.get_lengths()[2]); k++)
{
printf(
"bq_shuffle_host[%d][%d][%d]: %f\n", i, j, k, bq_shuffle_host(i, j, k));
}
}
}
// printf("Preshuffle BQ tensor\n");
// for(int i = 0; i < static_cast<int>(bq_shuffle_host.get_lengths()[0]); i++)
// {
// for(int j = 0; j < static_cast<int>(bq_shuffle_host.get_lengths()[1]); j++)
// {
// for(int k = 0; k < static_cast<int>(bq_shuffle_host.get_lengths()[2]); k++)
// {
// printf(
// "bq_shuffle_host[%d][%d][%d]: %f\n", i, j, k, bq_shuffle_host(i, j,
// k));
// }
// }
// }
bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data());
}
else

View File

@@ -63,22 +63,22 @@ auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
int n_ = lengths[1];
ck_tile::HostTensor<T> t_view({n_, bqk_dim / block_bq_k, block_bq_k});
std::copy(t->begin(), t->end(), t_view.begin());
printf("I am inside shuffle_bq\n");
printf("t_view.get_lengths(): %lu, %lu, %lu\n",
t_view.get_lengths()[0],
t_view.get_lengths()[1],
t_view.get_lengths()[2]);
for(int i = 0; i < static_cast<int>(t_view.get_lengths()[0]); i++)
{
for(int j = 0; j < static_cast<int>(t_view.get_lengths()[1]); j++)
{
for(int k = 0; k < static_cast<int>(t_view.get_lengths()[2]); k++)
{
printf("t_view[%d][%d][%d]: %f\n", i, j, k, t_view(i, j, k));
}
}
}
printf("I am inside shuffle_bq\n");
// printf("I am inside shuffle_bq\n");
// printf("t_view.get_lengths(): %lu, %lu, %lu\n",
// t_view.get_lengths()[0],
// t_view.get_lengths()[1],
// t_view.get_lengths()[2]);
// for(int i = 0; i < static_cast<int>(t_view.get_lengths()[0]); i++)
// {
// for(int j = 0; j < static_cast<int>(t_view.get_lengths()[1]); j++)
// {
// for(int k = 0; k < static_cast<int>(t_view.get_lengths()[2]); k++)
// {
// printf("t_view[%d][%d][%d]: %f\n", i, j, k, t_view(i, j, k));
// }
// }
// }
// printf("I am inside shuffle_bq\n");
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
}

View File

@@ -195,6 +195,18 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
float scale_reg_f = cvt_scale_to_fp32(gathered_scale_reg);
printf("get_block_id(): %d, get_thread_id(): %d, nIter: %d, lane_id(): %d, "
"kQScale: %d, pull_from_lane: %d, scale_reg: %f, "
"gathered_scale_reg: %d, scale_reg_f: %f\n",
get_block_id(),
get_thread_id(),
nIter,
__lane_id(),
static_cast<int>(kQScale),
pull_from_lane,
scale_reg,
gathered_scale_reg,
scale_reg_f);
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];

View File

@@ -355,6 +355,23 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
// if(get_block_id() ==0 && get_thread_id() == 0){
printf("get_block_id(): %d, get_thread_id(): %d, nIter: %d, lane_id(): "
"%u, KQPerBLock: %d, "
"kQScale: %d, pull_from_lane: %u, scale_reg: %f, "
"gathered_scale_reg: %d, scale_reg_f: %f\n",
get_block_id(),
get_thread_id(),
static_cast<int>(nIter),
__lane_id(),
Traits::KQPerBlock,
static_cast<int>(kQScale),
pull_from_lane,
scale_reg,
gathered_scale_reg,
scale_reg_f);
//}
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=

View File

@@ -280,12 +280,13 @@ struct QuantGemmKernel
// Helper: Create Pre-shuffled Quantization Tensor Descriptor
// ===================================================================
template <index_t KPerBlockBQ,
index_t NPerBlockBQ,
index_t NPerBlock,
index_t WarpTileN,
index_t GetVectorSizeBQ,
typename BQDataType_>
CK_TILE_DEVICE static auto
MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QK_B)
MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QN_B, index_t QK_B)
{
// Step 1: Calculate base BQ tensor dimensions
// ----------------------------------------------------------
@@ -316,7 +317,7 @@ struct QuantGemmKernel
// ----------------------------------------------------------
// Pad the X dimension to be a multiple of block_tile_size to ensure
// each thread block can process complete tiles without edge cases
const auto block_tile_size = NPerBlock * KPerBlockBQ; // 64 * 2 =128
const auto block_tile_size = NPerBlockBQ * KPerBlockBQ; // 64 * 2 = 128 || 8 * 2 = 16
if(get_block_id() == 0 && get_thread_id() == 0)
{
@@ -327,7 +328,7 @@ struct QuantGemmKernel
bq_desc,
make_tuple(
make_pass_through_transform(bq_y),
make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))), // 2, 128
make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))), // 2, 16
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
@@ -336,22 +337,28 @@ struct QuantGemmKernel
// Split the X dimension into [wave_tile_count_x, wave_tile_size]
// This separates the work into tiles that can be processed by
// individual warps/waves
const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1];
const auto wave_tile_size = WarpTileN * KPerBlockBQ;
const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size);
const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; // 128 || 16
const auto wave_tile_size = (WarpTileN / QN_B) * KPerBlockBQ; // 16 * 2= 32 || 16/8 x 2 = 4
const auto wave_tile_count_x =
ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); // 128/32 = 4 || 16/4 = 4
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("pad_bq_x:%d, wave_tile_size: %d, wave_tile_count_x: %d\n",
printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size: "
"%d, wave_tile_count_x: %d\n",
pad_bq_x,
WarpTileN,
NPerBlockBQ,
KPerBlockBQ,
wave_tile_size,
wave_tile_count_x);
}
const auto bq_unmerge_pad0_desc = transform_tensor_descriptor(
bq_pad0_desc,
make_tuple(make_pass_through_transform(bq_y),
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
make_tuple(
make_pass_through_transform(bq_y),
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), // 2, 4, 4
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
@@ -361,10 +368,11 @@ struct QuantGemmKernel
// This ensures coalesced memory accesses within each warp
const auto bq_pad1_desc = transform_tensor_descriptor(
bq_unmerge_pad0_desc,
make_tuple(make_pass_through_transform(bq_y),
make_pass_through_transform(wave_tile_count_x),
make_right_pad_transform(wave_tile_size,
get_padding_size(wave_tile_size, get_warp_size()))),
make_tuple(
make_pass_through_transform(bq_y), // 2
make_pass_through_transform(wave_tile_count_x), // 4
make_right_pad_transform(wave_tile_size,
get_padding_size(wave_tile_size, get_warp_size()))), // 64
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
@@ -381,8 +389,8 @@ struct QuantGemmKernel
}
const auto bq_merge_pad1_desc = transform_tensor_descriptor(
bq_pad1_desc,
make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)),
make_pass_through_transform(pad_wave_size)),
make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)), // 8
make_pass_through_transform(pad_wave_size)), // 64
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
@@ -817,11 +825,13 @@ struct QuantGemmKernel
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return MakePreshuffledQuantTensorView<
GemmPipeline::KPerBlockBQ,
GemmPipeline::NPerBlockBQ,
GemmPipeline::NPerBlock,
TilePartitioner::BlockGemmShape::WarpTile::at(I1),
GemmPipeline::GetVectorSizeBQ()>(
bq_ptr,
ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN),
QuantGroupSize::kN,
kargs.QK_B);
}
else
@@ -1170,7 +1180,7 @@ struct QuantGemmKernel
// bq_block_window.template print_tile_window_range<BQDataType>(
// 0, 128, 0, 2, "bq block window");
bq_block_window.template print_tile_window_range<BQDataType>(
0, 8, 0, 64, "bq block window");
0, 1, 0, 64, "bq block window");
}
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n);

View File

@@ -35,12 +35,12 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BlockGemmShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
// constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
@@ -61,7 +61,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
BlockSize,
NPerBlock / WarpGemm::kN,
ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()),
VecLoadSize,
Problem::QuantGroupSize::kN,
PreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}

View File

@@ -189,8 +189,8 @@ struct tile_distribution_encoding_pattern_aq_transposed_c
template <typename BlockGemmShape,
typename WarpGemm,
index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t YPerTile, // 4
index_t XPerTile, // 64
index_t YPerQ,
bool PreshuffleQuant = false>
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
@@ -236,41 +236,65 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
{
if constexpr(PreshuffleQuant)
{
constexpr index_t X1 = warp_size;
constexpr index_t X0 = XPerTile / warp_size;
constexpr index_t Y1 = NWarps;
constexpr index_t Y0 = YPerTile / Y1;
// constexpr index_t X1 = warp_size;
// constexpr index_t X0 = XPerTile / warp_size;
// constexpr index_t Y1 = NWarps;
// constexpr index_t Y0 = YPerTile / Y1;
// return make_static_tile_distribution(
// tile_distribution_encoding<sequence<MWarps>,
// tuple<sequence<Y0, Y1>, sequence<X0, X1>>,
// tuple<sequence<0, 1>, sequence<2>>, // (Mwarp, Nwarp),
// (X1 = warp_size(64)) tuple<sequence<0, 1>,
// sequence<1>>, sequence<1, 2>, //(NiterPerWarp,
// X(threads in x dimension, 1)) sequence<0, 0>>{});
// constexpr index_t X1 = warp_size; //64
constexpr index_t X0 = XPerTile / warp_size; // 64/64 = 1
constexpr index_t X1 = XPerTile / WarpGemm::kN; // 64/16 = 4
constexpr index_t X2 = WarpGemm::kN / YPerQ; // 16/8=2
constexpr index_t XR = YPerQ; // 8
constexpr index_t Y1 = NWarps; // 4
constexpr index_t Y0 = YPerTile / Y1; // 1
constexpr index_t YR = 1;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps>,
tuple<sequence<Y0, Y1>, sequence<X0, X1>>,
tuple<sequence<0, 1>, sequence<2>>,
tuple<sequence<0, 1>, sequence<1>>,
sequence<1, 2>,
sequence<0, 0>>{});
tile_distribution_encoding<
sequence<MWarps, XR, YR>,
tuple<sequence<Y0, Y1>, sequence<X0, X1, X2>>,
tuple<sequence<0, 1>, sequence<0, 2, 0>>, // (Mwarp, Nwarp),
tuple<sequence<0, 1>,
sequence<1, 2, 2>>, //(repeat for 8 threads in X direction, X2(no of
// scales per warp), X1(warp_size/quant_group_size),
// YR)(8, 2, 4, 1)
sequence<1, 2>,
sequence<0, 0>>{});
}
else
{
if constexpr(YPerQ < WarpGemm::kN)
{
// Case 1: Fine-grained - multiple quantization scales within a single warp
constexpr index_t X = XPerTile; // Full X dimension of tile
constexpr index_t XR = 1; // No Y replication needed
constexpr index_t Y0 = NIterPerWarp; // Iterations per warp in N-dim
constexpr index_t Y1 = NWarps; // Number of warps in N-dim
constexpr index_t Y2 = WarpGemm::kN / YPerQ; // Number of scales per warp
constexpr index_t YR = YPerQ; // Elements per quantization group
constexpr index_t X = XPerTile; // Full X dimension of tile
constexpr index_t XR = 1; // No Y replication needed
constexpr index_t Y0 = NIterPerWarp; // Iterations per warp in N-dim
constexpr index_t Y1 = NWarps; // Number of warps in N-dim
constexpr index_t Y2 =
WarpGemm::kN / YPerQ; // Number of scales per warp 16/ 8 = 2
constexpr index_t YR = YPerQ; // Elements per quantization group 8
static_assert(Y0 * Y1 * Y2 == YPerTile,
"Y0, Y1, Y2 must cover the blocktile along Y.");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, XR, YR>,
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
tuple<sequence<0, 1>, sequence<0, 1, 0>>,
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
sequence<1, 2>,
sequence<0, 0>>{});
tile_distribution_encoding<
sequence<MWarps, XR, YR>,
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
tuple<sequence<0, 1>, sequence<0, 1, 0>>, //(Mwarp, Nwarp), (XR, Y2[no of
// scales per warp], YR)
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
sequence<1, 2>, //(NiterPerWarp, X(threads in x dimension))
sequence<0, 0>>{});
}
else if constexpr(YPerQ <= WarpGemm::kN * NWarps)
{