debugging PermuteN

This commit is contained in:
khuagarw
2025-11-24 18:58:51 +00:00
parent cf3f9b57b4
commit 04aaf97192
6 changed files with 92 additions and 35 deletions

View File

@@ -523,8 +523,31 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
// ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
if(bq_tensor_ptr)
{
BQDataType value = 1.0f;
for(int i = 0; i < BQK; i++)
{
for(int j = 0; j < N / QuantGroupSize::kN; j += (16 / QuantGroupSize::kN))
{
for(int k = 0; k < 16 / QuantGroupSize::kN; k++)
{
(*bq_tensor_ptr)(i, j + k) = value;
}
value += static_cast<BQDataType>(1.0f);
}
}
}
// for(int i = 0; i < BQK; i++)
// {
// for(int j = 0; j < N / QuantGroupSize::kN; j++)
// {
// printf("%.2f ", (*bq_tensor_ptr)(i, j));
// }
// printf("\n");
// }
}
else
{

View File

@@ -111,8 +111,8 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1]; // 8
int bqk_ = t.get_lengths()[0]; // 1
int n_ = t.get_lengths()[1]; // 128
int bqk_ = t.get_lengths()[0]; // 1 x 128
constexpr int NRepeat =
GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; // 128/16/4 = 2
@@ -120,9 +120,40 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile / group_n,
NRepeat,
bqk_}); //{1, 4, 16, 2, 1}
bqk_}); //{1, 4, 16, 2, 1}, group_n:16 {1, 4, 1, 2, 1}
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4}); //{1, 2, 4, 16, 1}
printf("I am inside bq_permuteN\n");
printf("t.get_lengths(): %lu, %lu, %lu, %lu, %lu\n",
t_view.get_lengths()[0],
t_view.get_lengths()[1],
t_view.get_lengths()[2],
t_view.get_lengths()[3],
t_view.get_lengths()[4]);
for(int i = 0; i < static_cast<int>(t.get_lengths()[0]); i++)
{
for(int j = 0; j < static_cast<int>(t_view.get_lengths()[1]); j++)
{
for(int k = 0; k < static_cast<int>(t_view.get_lengths()[2]); k++)
{
for(int l = 0; l < static_cast<int>(t_view.get_lengths()[3]); l++)
{
for(int m = 0; m < static_cast<int>(t_view.get_lengths()[4]); m++)
{
printf("t_view[%d][%d][%d][%d][%d]: %f\n",
i,
j,
k,
l,
m,
t_view(i, j, k, l, m));
}
}
}
}
}
printf("I am inside bq_permuteN\n");
return ck_tile::reference_permute(
t_view, {0, 3, 1, 2, 4}); // {1, 2, 4, 16, 1}, group_n 16 {1, 2, 4, 1, 1}
}
template <typename GemmConfig, typename T>

View File

@@ -220,20 +220,21 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
float scale_reg_f = cvt_scale_to_fp32(scale_reg);
// if(get_block_id() == 0 && get_thread_id() == 1)
//{
printf("get_block_id(): %d, get_warp_id(): %d, get_thread_id(): %d, nIter: "
"%d, NWarp: %d, WG::kN: %d, QuantGroupSize::kN: %d, "
"KPerBlockBQ: %d, kQScale: %d, scale_reg_f: %f, reg_offset: %d\n",
get_block_id(),
get_warp_id(),
get_thread_id(),
static_cast<int>(nIter),
NWarp,
WG::kN,
static_cast<int>(QuantGroupSize::kN),
static_cast<int>(KPerBlockBQ),
static_cast<int>(kQScale),
scale_reg_f,
reg_offset);
// printf("get_block_id(): %d, get_warp_id(): %d, get_thread_id(): %d,
// nIter: "
// "%d, NWarp: %d, WG::kN: %d, QuantGroupSize::kN: %d, "
// "KPerBlockBQ: %d, kQScale: %d, scale_reg_f: %f, reg_offset: %d\n",
// get_block_id(),
// get_warp_id(),
// get_thread_id(),
// static_cast<int>(nIter),
// NWarp,
// WG::kN,
// static_cast<int>(QuantGroupSize::kN),
// static_cast<int>(KPerBlockBQ),
// static_cast<int>(kQScale),
// scale_reg_f,
// reg_offset);
//}
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {

View File

@@ -1167,7 +1167,7 @@ struct QuantGemmKernel
if(get_block_id() == 0 && get_thread_id() == 0)
{
bq_block_window.template print_tile_window_range<BQDataType>(
0, 1, 0, 16, "bq block window");
0, 1, 0, 128, "bq block window");
}
return GemmPipeline{}.template operator()(a_block_window,
b_block_window,

View File

@@ -71,8 +71,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
tile_distribution_encoding_pattern_bq<BlockGemmShape,
WarpGemm,
BlockSize,
KPerBlockBQ,
NPerBlockBQ,
KPerBlockBQ, // 128/128 = 1
NPerBlockBQ, // 128/16 = 8
Problem::QuantGroupSize::kN>;
return TileEncodingPattern::make_2d_static_tile_distribution();

View File

@@ -169,9 +169,9 @@ struct tile_distribution_encoding_pattern_aq_transposed_c
template <typename BlockGemmShape,
typename WarpGemm,
index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t XPerQ,
index_t YPerTile, // 1
index_t XPerTile, // 8
index_t XPerQ, // 16
bool PreshuffleQuant = false>
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
{
@@ -255,16 +255,18 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
else if constexpr(XPerQ <= WarpGemm::kN * NWarps)
{
// Case 2: Medium-grained - one quantization scale per warp
constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
constexpr auto X1 = NWarps / XR; // Warps per unique scale
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
constexpr auto XR =
XPerQ / WarpGemm::kN; // Scale replication factor //16/16 = 1
constexpr auto X1 = NWarps / XR; // Warps per unique scale //4/1 = 4
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension //8/4 = 2
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, XR, get_warp_size()>,
tuple<sequence<YPerTile>, sequence<X0, X1>>,
tuple<sequence<0, 2, 0>, sequence<0>>,
tuple<sequence<0, 1, 1>, sequence<2>>,
sequence<2, 1>,
sequence<0, 0>>{});
tile_distribution_encoding<
sequence<MWarps, XR, get_warp_size()>, // 1, 1, 64
tuple<sequence<YPerTile>, sequence<X0, X1>>, // 1, (2, 4)
tuple<sequence<0, 2, 0>, sequence<0>>, //(1, 4, 1) (64)
tuple<sequence<0, 1, 1>, sequence<2>>,
sequence<2, 1>, //(2, 1(in Y dimension))
sequence<0, 0>>{});
}
else // XPerQ > WarpGemm::kN * NWarps
{