mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
debugging
This commit is contained in:
@@ -677,18 +677,19 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
{
|
||||
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
printf("Preshuffle BQ tensor\n");
|
||||
for(int i = 0; i < static_cast<int>(bq_shuffle_host.get_lengths()[0]); i++)
|
||||
{
|
||||
for(int j = 0; j < static_cast<int>(bq_shuffle_host.get_lengths()[1]); j++)
|
||||
{
|
||||
for(int k = 0; k < static_cast<int>(bq_shuffle_host.get_lengths()[2]); k++)
|
||||
{
|
||||
printf(
|
||||
"bq_shuffle_host[%d][%d][%d]: %f\n", i, j, k, bq_shuffle_host(i, j, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
// printf("Preshuffle BQ tensor\n");
|
||||
// for(int i = 0; i < static_cast<int>(bq_shuffle_host.get_lengths()[0]); i++)
|
||||
// {
|
||||
// for(int j = 0; j < static_cast<int>(bq_shuffle_host.get_lengths()[1]); j++)
|
||||
// {
|
||||
// for(int k = 0; k < static_cast<int>(bq_shuffle_host.get_lengths()[2]); k++)
|
||||
// {
|
||||
// printf(
|
||||
// "bq_shuffle_host[%d][%d][%d]: %f\n", i, j, k, bq_shuffle_host(i, j,
|
||||
// k));
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
|
||||
@@ -63,22 +63,22 @@ auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
|
||||
int n_ = lengths[1];
|
||||
ck_tile::HostTensor<T> t_view({n_, bqk_dim / block_bq_k, block_bq_k});
|
||||
std::copy(t->begin(), t->end(), t_view.begin());
|
||||
printf("I am inside shuffle_bq\n");
|
||||
printf("t_view.get_lengths(): %lu, %lu, %lu\n",
|
||||
t_view.get_lengths()[0],
|
||||
t_view.get_lengths()[1],
|
||||
t_view.get_lengths()[2]);
|
||||
for(int i = 0; i < static_cast<int>(t_view.get_lengths()[0]); i++)
|
||||
{
|
||||
for(int j = 0; j < static_cast<int>(t_view.get_lengths()[1]); j++)
|
||||
{
|
||||
for(int k = 0; k < static_cast<int>(t_view.get_lengths()[2]); k++)
|
||||
{
|
||||
printf("t_view[%d][%d][%d]: %f\n", i, j, k, t_view(i, j, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
printf("I am inside shuffle_bq\n");
|
||||
// printf("I am inside shuffle_bq\n");
|
||||
// printf("t_view.get_lengths(): %lu, %lu, %lu\n",
|
||||
// t_view.get_lengths()[0],
|
||||
// t_view.get_lengths()[1],
|
||||
// t_view.get_lengths()[2]);
|
||||
// for(int i = 0; i < static_cast<int>(t_view.get_lengths()[0]); i++)
|
||||
// {
|
||||
// for(int j = 0; j < static_cast<int>(t_view.get_lengths()[1]); j++)
|
||||
// {
|
||||
// for(int k = 0; k < static_cast<int>(t_view.get_lengths()[2]); k++)
|
||||
// {
|
||||
// printf("t_view[%d][%d][%d]: %f\n", i, j, k, t_view(i, j, k));
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// printf("I am inside shuffle_bq\n");
|
||||
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,6 +195,18 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
|
||||
|
||||
float scale_reg_f = cvt_scale_to_fp32(gathered_scale_reg);
|
||||
printf("get_block_id(): %d, get_thread_id(): %d, nIter: %d, lane_id(): %d, "
|
||||
"kQScale: %d, pull_from_lane: %d, scale_reg: %f, "
|
||||
"gathered_scale_reg: %d, scale_reg_f: %f\n",
|
||||
get_block_id(),
|
||||
get_thread_id(),
|
||||
nIter,
|
||||
__lane_id(),
|
||||
static_cast<int>(kQScale),
|
||||
pull_from_lane,
|
||||
scale_reg,
|
||||
gathered_scale_reg,
|
||||
scale_reg_f);
|
||||
|
||||
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
|
||||
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
|
||||
|
||||
@@ -355,6 +355,23 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
|
||||
|
||||
// if(get_block_id() ==0 && get_thread_id() == 0){
|
||||
printf("get_block_id(): %d, get_thread_id(): %d, nIter: %d, lane_id(): "
|
||||
"%u, KQPerBLock: %d, "
|
||||
"kQScale: %d, pull_from_lane: %u, scale_reg: %f, "
|
||||
"gathered_scale_reg: %d, scale_reg_f: %f\n",
|
||||
get_block_id(),
|
||||
get_thread_id(),
|
||||
static_cast<int>(nIter),
|
||||
__lane_id(),
|
||||
Traits::KQPerBlock,
|
||||
static_cast<int>(kQScale),
|
||||
pull_from_lane,
|
||||
scale_reg,
|
||||
gathered_scale_reg,
|
||||
scale_reg_f);
|
||||
//}
|
||||
|
||||
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
|
||||
[&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
|
||||
@@ -280,12 +280,13 @@ struct QuantGemmKernel
|
||||
// Helper: Create Pre-shuffled Quantization Tensor Descriptor
|
||||
// ===================================================================
|
||||
template <index_t KPerBlockBQ,
|
||||
index_t NPerBlockBQ,
|
||||
index_t NPerBlock,
|
||||
index_t WarpTileN,
|
||||
index_t GetVectorSizeBQ,
|
||||
typename BQDataType_>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QK_B)
|
||||
MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QN_B, index_t QK_B)
|
||||
{
|
||||
// Step 1: Calculate base BQ tensor dimensions
|
||||
// ----------------------------------------------------------
|
||||
@@ -316,7 +317,7 @@ struct QuantGemmKernel
|
||||
// ----------------------------------------------------------
|
||||
// Pad the X dimension to be a multiple of block_tile_size to ensure
|
||||
// each thread block can process complete tiles without edge cases
|
||||
const auto block_tile_size = NPerBlock * KPerBlockBQ; // 64 * 2 =128
|
||||
const auto block_tile_size = NPerBlockBQ * KPerBlockBQ; // 64 * 2 = 128 || 8 * 2 = 16
|
||||
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
@@ -327,7 +328,7 @@ struct QuantGemmKernel
|
||||
bq_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(bq_y),
|
||||
make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))), // 2, 128
|
||||
make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))), // 2, 16
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
@@ -336,22 +337,28 @@ struct QuantGemmKernel
|
||||
// Split the X dimension into [wave_tile_count_x, wave_tile_size]
|
||||
// This separates the work into tiles that can be processed by
|
||||
// individual warps/waves
|
||||
const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1];
|
||||
const auto wave_tile_size = WarpTileN * KPerBlockBQ;
|
||||
const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size);
|
||||
const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; // 128 || 16
|
||||
const auto wave_tile_size = (WarpTileN / QN_B) * KPerBlockBQ; // 16 * 2= 32 || 16/8 x 2 = 4
|
||||
const auto wave_tile_count_x =
|
||||
ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); // 128/32 = 4 || 16/4 = 4
|
||||
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("pad_bq_x:%d, wave_tile_size: %d, wave_tile_count_x: %d\n",
|
||||
printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size: "
|
||||
"%d, wave_tile_count_x: %d\n",
|
||||
pad_bq_x,
|
||||
WarpTileN,
|
||||
NPerBlockBQ,
|
||||
KPerBlockBQ,
|
||||
wave_tile_size,
|
||||
wave_tile_count_x);
|
||||
}
|
||||
|
||||
const auto bq_unmerge_pad0_desc = transform_tensor_descriptor(
|
||||
bq_pad0_desc,
|
||||
make_tuple(make_pass_through_transform(bq_y),
|
||||
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
|
||||
make_tuple(
|
||||
make_pass_through_transform(bq_y),
|
||||
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), // 2, 4, 4
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
@@ -361,10 +368,11 @@ struct QuantGemmKernel
|
||||
// This ensures coalesced memory accesses within each warp
|
||||
const auto bq_pad1_desc = transform_tensor_descriptor(
|
||||
bq_unmerge_pad0_desc,
|
||||
make_tuple(make_pass_through_transform(bq_y),
|
||||
make_pass_through_transform(wave_tile_count_x),
|
||||
make_right_pad_transform(wave_tile_size,
|
||||
get_padding_size(wave_tile_size, get_warp_size()))),
|
||||
make_tuple(
|
||||
make_pass_through_transform(bq_y), // 2
|
||||
make_pass_through_transform(wave_tile_count_x), // 4
|
||||
make_right_pad_transform(wave_tile_size,
|
||||
get_padding_size(wave_tile_size, get_warp_size()))), // 64
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
@@ -381,8 +389,8 @@ struct QuantGemmKernel
|
||||
}
|
||||
const auto bq_merge_pad1_desc = transform_tensor_descriptor(
|
||||
bq_pad1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)),
|
||||
make_pass_through_transform(pad_wave_size)),
|
||||
make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)), // 8
|
||||
make_pass_through_transform(pad_wave_size)), // 64
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
@@ -817,11 +825,13 @@ struct QuantGemmKernel
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
return MakePreshuffledQuantTensorView<
|
||||
GemmPipeline::KPerBlockBQ,
|
||||
GemmPipeline::NPerBlockBQ,
|
||||
GemmPipeline::NPerBlock,
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(I1),
|
||||
GemmPipeline::GetVectorSizeBQ()>(
|
||||
bq_ptr,
|
||||
ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN),
|
||||
QuantGroupSize::kN,
|
||||
kargs.QK_B);
|
||||
}
|
||||
else
|
||||
@@ -1170,7 +1180,7 @@ struct QuantGemmKernel
|
||||
// bq_block_window.template print_tile_window_range<BQDataType>(
|
||||
// 0, 128, 0, 2, "bq block window");
|
||||
bq_block_window.template print_tile_window_range<BQDataType>(
|
||||
0, 8, 0, 64, "bq block window");
|
||||
0, 1, 0, 64, "bq block window");
|
||||
}
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n);
|
||||
|
||||
@@ -35,12 +35,12 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BlockGemmShape = typename Problem::BlockGemmShape;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
// constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
|
||||
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
@@ -61,7 +61,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
BlockSize,
|
||||
NPerBlock / WarpGemm::kN,
|
||||
ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()),
|
||||
VecLoadSize,
|
||||
Problem::QuantGroupSize::kN,
|
||||
PreshuffleQuant>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
|
||||
@@ -189,8 +189,8 @@ struct tile_distribution_encoding_pattern_aq_transposed_c
|
||||
template <typename BlockGemmShape,
|
||||
typename WarpGemm,
|
||||
index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t YPerTile, // 4
|
||||
index_t XPerTile, // 64
|
||||
index_t YPerQ,
|
||||
bool PreshuffleQuant = false>
|
||||
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
|
||||
@@ -236,41 +236,65 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
{
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
constexpr index_t X1 = warp_size;
|
||||
constexpr index_t X0 = XPerTile / warp_size;
|
||||
constexpr index_t Y1 = NWarps;
|
||||
constexpr index_t Y0 = YPerTile / Y1;
|
||||
// constexpr index_t X1 = warp_size;
|
||||
// constexpr index_t X0 = XPerTile / warp_size;
|
||||
// constexpr index_t Y1 = NWarps;
|
||||
// constexpr index_t Y0 = YPerTile / Y1;
|
||||
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding<sequence<MWarps>,
|
||||
// tuple<sequence<Y0, Y1>, sequence<X0, X1>>,
|
||||
// tuple<sequence<0, 1>, sequence<2>>, // (Mwarp, Nwarp),
|
||||
// (X1 = warp_size(64)) tuple<sequence<0, 1>,
|
||||
// sequence<1>>, sequence<1, 2>, //(NiterPerWarp,
|
||||
// X(threads in x dimension, 1)) sequence<0, 0>>{});
|
||||
|
||||
// constexpr index_t X1 = warp_size; //64
|
||||
constexpr index_t X0 = XPerTile / warp_size; // 64/64 = 1
|
||||
constexpr index_t X1 = XPerTile / WarpGemm::kN; // 64/16 = 4
|
||||
constexpr index_t X2 = WarpGemm::kN / YPerQ; // 16/8=2
|
||||
constexpr index_t XR = YPerQ; // 8
|
||||
constexpr index_t Y1 = NWarps; // 4
|
||||
constexpr index_t Y0 = YPerTile / Y1; // 1
|
||||
constexpr index_t YR = 1;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps>,
|
||||
tuple<sequence<Y0, Y1>, sequence<X0, X1>>,
|
||||
tuple<sequence<0, 1>, sequence<2>>,
|
||||
tuple<sequence<0, 1>, sequence<1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
tile_distribution_encoding<
|
||||
sequence<MWarps, XR, YR>,
|
||||
tuple<sequence<Y0, Y1>, sequence<X0, X1, X2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2, 0>>, // (Mwarp, Nwarp),
|
||||
tuple<sequence<0, 1>,
|
||||
sequence<1, 2, 2>>, //(repeat for 8 threads in X direction, X2(no of
|
||||
// scales per warp), X1(warp_size/quant_group_size),
|
||||
// YR)(8, 2, 4, 1)
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(YPerQ < WarpGemm::kN)
|
||||
{
|
||||
// Case 1: Fine-grained - multiple quantization scales within a single warp
|
||||
constexpr index_t X = XPerTile; // Full X dimension of tile
|
||||
constexpr index_t XR = 1; // No Y replication needed
|
||||
constexpr index_t Y0 = NIterPerWarp; // Iterations per warp in N-dim
|
||||
constexpr index_t Y1 = NWarps; // Number of warps in N-dim
|
||||
constexpr index_t Y2 = WarpGemm::kN / YPerQ; // Number of scales per warp
|
||||
constexpr index_t YR = YPerQ; // Elements per quantization group
|
||||
constexpr index_t X = XPerTile; // Full X dimension of tile
|
||||
constexpr index_t XR = 1; // No Y replication needed
|
||||
constexpr index_t Y0 = NIterPerWarp; // Iterations per warp in N-dim
|
||||
constexpr index_t Y1 = NWarps; // Number of warps in N-dim
|
||||
constexpr index_t Y2 =
|
||||
WarpGemm::kN / YPerQ; // Number of scales per warp 16/ 8 = 2
|
||||
constexpr index_t YR = YPerQ; // Elements per quantization group 8
|
||||
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile,
|
||||
"Y0, Y1, Y2 must cover the blocktile along Y.");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, XR, YR>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1, 0>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
tile_distribution_encoding<
|
||||
sequence<MWarps, XR, YR>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1, 0>>, //(Mwarp, Nwarp), (XR, Y2[no of
|
||||
// scales per warp], YR)
|
||||
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
|
||||
sequence<1, 2>, //(NiterPerWarp, X(threads in x dimension))
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
else if constexpr(YPerQ <= WarpGemm::kN * NWarps)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user