mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
debugging PermuteN
This commit is contained in:
@@ -523,8 +523,31 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
|
||||
// ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
if(bq_tensor_ptr)
|
||||
{
|
||||
BQDataType value = 1.0f;
|
||||
for(int i = 0; i < BQK; i++)
|
||||
{
|
||||
for(int j = 0; j < N / QuantGroupSize::kN; j += (16 / QuantGroupSize::kN))
|
||||
{
|
||||
for(int k = 0; k < 16 / QuantGroupSize::kN; k++)
|
||||
{
|
||||
(*bq_tensor_ptr)(i, j + k) = value;
|
||||
}
|
||||
value += static_cast<BQDataType>(1.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
// for(int i = 0; i < BQK; i++)
|
||||
// {
|
||||
// for(int j = 0; j < N / QuantGroupSize::kN; j++)
|
||||
// {
|
||||
// printf("%.2f ", (*bq_tensor_ptr)(i, j));
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -111,8 +111,8 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
|
||||
int n_ = t.get_lengths()[1]; // 8
|
||||
int bqk_ = t.get_lengths()[0]; // 1
|
||||
int n_ = t.get_lengths()[1]; // 128
|
||||
int bqk_ = t.get_lengths()[0]; // 1 x 128
|
||||
constexpr int NRepeat =
|
||||
GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; // 128/16/4 = 2
|
||||
|
||||
@@ -120,9 +120,40 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::N_Warp_Tile / group_n,
|
||||
NRepeat,
|
||||
bqk_}); //{1, 4, 16, 2, 1}
|
||||
bqk_}); //{1, 4, 16, 2, 1}, group_n:16 {1, 4, 1, 2, 1}
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4}); //{1, 2, 4, 16, 1}
|
||||
printf("I am inside bq_permuteN\n");
|
||||
printf("t.get_lengths(): %lu, %lu, %lu, %lu, %lu\n",
|
||||
t_view.get_lengths()[0],
|
||||
t_view.get_lengths()[1],
|
||||
t_view.get_lengths()[2],
|
||||
t_view.get_lengths()[3],
|
||||
t_view.get_lengths()[4]);
|
||||
for(int i = 0; i < static_cast<int>(t.get_lengths()[0]); i++)
|
||||
{
|
||||
for(int j = 0; j < static_cast<int>(t_view.get_lengths()[1]); j++)
|
||||
{
|
||||
for(int k = 0; k < static_cast<int>(t_view.get_lengths()[2]); k++)
|
||||
{
|
||||
for(int l = 0; l < static_cast<int>(t_view.get_lengths()[3]); l++)
|
||||
{
|
||||
for(int m = 0; m < static_cast<int>(t_view.get_lengths()[4]); m++)
|
||||
{
|
||||
printf("t_view[%d][%d][%d][%d][%d]: %f\n",
|
||||
i,
|
||||
j,
|
||||
k,
|
||||
l,
|
||||
m,
|
||||
t_view(i, j, k, l, m));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
printf("I am inside bq_permuteN\n");
|
||||
return ck_tile::reference_permute(
|
||||
t_view, {0, 3, 1, 2, 4}); // {1, 2, 4, 16, 1}, group_n 16 {1, 2, 4, 1, 1}
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
|
||||
@@ -220,20 +220,21 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
float scale_reg_f = cvt_scale_to_fp32(scale_reg);
|
||||
// if(get_block_id() == 0 && get_thread_id() == 1)
|
||||
//{
|
||||
printf("get_block_id(): %d, get_warp_id(): %d, get_thread_id(): %d, nIter: "
|
||||
"%d, NWarp: %d, WG::kN: %d, QuantGroupSize::kN: %d, "
|
||||
"KPerBlockBQ: %d, kQScale: %d, scale_reg_f: %f, reg_offset: %d\n",
|
||||
get_block_id(),
|
||||
get_warp_id(),
|
||||
get_thread_id(),
|
||||
static_cast<int>(nIter),
|
||||
NWarp,
|
||||
WG::kN,
|
||||
static_cast<int>(QuantGroupSize::kN),
|
||||
static_cast<int>(KPerBlockBQ),
|
||||
static_cast<int>(kQScale),
|
||||
scale_reg_f,
|
||||
reg_offset);
|
||||
// printf("get_block_id(): %d, get_warp_id(): %d, get_thread_id(): %d,
|
||||
// nIter: "
|
||||
// "%d, NWarp: %d, WG::kN: %d, QuantGroupSize::kN: %d, "
|
||||
// "KPerBlockBQ: %d, kQScale: %d, scale_reg_f: %f, reg_offset: %d\n",
|
||||
// get_block_id(),
|
||||
// get_warp_id(),
|
||||
// get_thread_id(),
|
||||
// static_cast<int>(nIter),
|
||||
// NWarp,
|
||||
// WG::kN,
|
||||
// static_cast<int>(QuantGroupSize::kN),
|
||||
// static_cast<int>(KPerBlockBQ),
|
||||
// static_cast<int>(kQScale),
|
||||
// scale_reg_f,
|
||||
// reg_offset);
|
||||
//}
|
||||
|
||||
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
|
||||
|
||||
@@ -1167,7 +1167,7 @@ struct QuantGemmKernel
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
bq_block_window.template print_tile_window_range<BQDataType>(
|
||||
0, 1, 0, 16, "bq block window");
|
||||
0, 1, 0, 128, "bq block window");
|
||||
}
|
||||
return GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
|
||||
@@ -71,8 +71,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
tile_distribution_encoding_pattern_bq<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
KPerBlockBQ,
|
||||
NPerBlockBQ,
|
||||
KPerBlockBQ, // 128/128 = 1
|
||||
NPerBlockBQ, // 128/16 = 8
|
||||
Problem::QuantGroupSize::kN>;
|
||||
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
|
||||
@@ -169,9 +169,9 @@ struct tile_distribution_encoding_pattern_aq_transposed_c
|
||||
template <typename BlockGemmShape,
|
||||
typename WarpGemm,
|
||||
index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t XPerQ,
|
||||
index_t YPerTile, // 1
|
||||
index_t XPerTile, // 8
|
||||
index_t XPerQ, // 16
|
||||
bool PreshuffleQuant = false>
|
||||
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
|
||||
{
|
||||
@@ -255,16 +255,18 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
else if constexpr(XPerQ <= WarpGemm::kN * NWarps)
|
||||
{
|
||||
// Case 2: Medium-grained - one quantization scale per warp
|
||||
constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
|
||||
constexpr auto X1 = NWarps / XR; // Warps per unique scale
|
||||
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
|
||||
constexpr auto XR =
|
||||
XPerQ / WarpGemm::kN; // Scale replication factor //16/16 = 1
|
||||
constexpr auto X1 = NWarps / XR; // Warps per unique scale //4/1 = 4
|
||||
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension //8/4 = 2
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, XR, get_warp_size()>,
|
||||
tuple<sequence<YPerTile>, sequence<X0, X1>>,
|
||||
tuple<sequence<0, 2, 0>, sequence<0>>,
|
||||
tuple<sequence<0, 1, 1>, sequence<2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{});
|
||||
tile_distribution_encoding<
|
||||
sequence<MWarps, XR, get_warp_size()>, // 1, 1, 64
|
||||
tuple<sequence<YPerTile>, sequence<X0, X1>>, // 1, (2, 4)
|
||||
tuple<sequence<0, 2, 0>, sequence<0>>, //(1, 4, 1) (64)
|
||||
tuple<sequence<0, 1, 1>, sequence<2>>,
|
||||
sequence<2, 1>, //(2, 1(in Y dimension))
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
else // XPerQ > WarpGemm::kN * NWarps
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user