mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
formatting
This commit is contained in:
@@ -16,7 +16,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
# gemm_aquant_quantgrouped_preshufflequant.cpp
|
||||
# gemm_bquant_quantgrouped_bf8i4.cpp
|
||||
# gemm_bquant_quantgrouped_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_bf16mxfp4.cpp
|
||||
# gemm_bquant_quantgrouped_bf16mxfp4.cpp
|
||||
# gemm_bquant_quantgrouped_bf8.cpp
|
||||
# gemm_bquant_quantgrouped_fp8.cpp
|
||||
# gemm_bquant_quantgrouped_preshuffleb.cpp
|
||||
|
||||
@@ -35,30 +35,45 @@ void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
// lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant",
|
||||
// "1x1x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
|
||||
@@ -33,14 +33,14 @@ auto create_args(int argc, char* argv[])
|
||||
"fp8",
|
||||
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
|
||||
"bf8i4 or bf16fp4")
|
||||
.insert("warmup", "1", "Number of iterations before benchmarking the kernel")
|
||||
.insert("repeat", "0", "Number of iterations to benchmark the kernel")
|
||||
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
|
||||
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "SplitK value")
|
||||
.insert("device", "0", "Device id that will be used to run the kernel")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("flush_cache", "true", "Flush cache before running the kernel")
|
||||
.insert("rotating_count", "0", "Rotating count")
|
||||
.insert("rotating_count", "1000", "Rotating count")
|
||||
.insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol")
|
||||
.insert("preshuffleb", "false", "Enable preshuffle of tensor B")
|
||||
.insert("preshufflequant", "false", "Enable preshuffle of quant tensor")
|
||||
|
||||
@@ -357,9 +357,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
<< " C_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::CDataType>::name
|
||||
<< " QuantMode = " << quant_type_to_string(QuantMode)
|
||||
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
|
||||
<< " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
|
||||
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
<< " PreshuffleB = \n"
|
||||
<< (GemmConfig::PreshuffleB ? "true" : "false") << " : " << ave_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -349,6 +349,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
constexpr index_t reg_offset = nIter;
|
||||
auto pull_from_lane =
|
||||
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
|
||||
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
// cross lane ops
|
||||
uint32_t scale_reg_dword;
|
||||
@@ -368,19 +369,20 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
|
||||
|
||||
printf("block_id: %d, warp_id: %d, thread_id(): %d, nIter: %d, lane_id(): "
|
||||
"%u, kQScale: %d, pull_from_lane: %u, scale_reg: %f, "
|
||||
"scale_reg_f: %f\n",
|
||||
get_block_id(),
|
||||
get_warp_id(),
|
||||
get_thread_id(),
|
||||
static_cast<int>(nIter),
|
||||
__lane_id(),
|
||||
static_cast<int>(kQScale),
|
||||
pull_from_lane,
|
||||
scale_reg,
|
||||
scale_reg_f);
|
||||
|
||||
// printf("block_id: %d, warp_id: %d, thread_id(): %d, nIter: %d,
|
||||
// lane_id(): "
|
||||
// "%u, kQScale: %d, pull_from_lane: %u, scale_reg: %f, "
|
||||
// "scale_reg_f: %f\n",
|
||||
// get_block_id(),
|
||||
// get_warp_id(),
|
||||
// get_thread_id(),
|
||||
// static_cast<int>(nIter),
|
||||
// __lane_id(),
|
||||
// static_cast<int>(kQScale),
|
||||
// pull_from_lane,
|
||||
// scale_reg,
|
||||
// scale_reg_f);
|
||||
|
||||
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
|
||||
[&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
|
||||
@@ -298,15 +298,15 @@ struct QuantGemmKernel
|
||||
const auto bq_x = N * KPerBlockBQ; // 2x2 = 4
|
||||
const auto bq_y = QK_B / KPerBlockBQ; // 4/2 = 2
|
||||
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("N:%d, QK_B:%d\n", N, QK_B);
|
||||
printf("bq_x: %d, bq_y: %d, getVectorSizeBQ: %d, kPerBlockBQ: %d\n",
|
||||
bq_x,
|
||||
bq_y,
|
||||
GetVectorSizeBQ,
|
||||
KPerBlockBQ);
|
||||
}
|
||||
// if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
// {
|
||||
// printf("N:%d, QK_B:%d\n", N, QK_B);
|
||||
// printf("bq_x: %d, bq_y: %d, getVectorSizeBQ: %d, kPerBlockBQ: %d\n",
|
||||
// bq_x,
|
||||
// bq_y,
|
||||
// GetVectorSizeBQ,
|
||||
// KPerBlockBQ);
|
||||
// }
|
||||
|
||||
const auto bq_desc = make_naive_tensor_descriptor(make_tuple(bq_y, bq_x),
|
||||
make_tuple(bq_x, 1),
|
||||
@@ -319,10 +319,10 @@ struct QuantGemmKernel
|
||||
// each thread block can process complete tiles without edge cases
|
||||
const auto block_tile_size = NPerBlockBQ * KPerBlockBQ; // 2x2 = 4
|
||||
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("block_tile_size:%d \n", block_tile_size);
|
||||
}
|
||||
// if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
// {
|
||||
// printf("block_tile_size:%d \n", block_tile_size);
|
||||
// }
|
||||
|
||||
const auto bq_pad0_desc = transform_tensor_descriptor(
|
||||
bq_desc,
|
||||
@@ -337,22 +337,25 @@ struct QuantGemmKernel
|
||||
// Split the X dimension into [wave_tile_count_x, wave_tile_size]
|
||||
// This separates the work into tiles that can be processed by
|
||||
// individual warps/waves
|
||||
const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; // 4
|
||||
const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1) * KPerBlockBQ; // 32/16 x 2 = 4 = 2
|
||||
const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; // 4
|
||||
const auto wave_tile_size =
|
||||
((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1 /*QN_B/WarpTileN*/) *
|
||||
KPerBlockBQ; // 32/16 x 2 = 4 = 2
|
||||
const auto wave_tile_count_x =
|
||||
ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); // 4/4 = 1 ==2
|
||||
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size: "
|
||||
"%d, wave_tile_count_x: %d\n",
|
||||
pad_bq_x,
|
||||
WarpTileN,
|
||||
NPerBlockBQ,
|
||||
KPerBlockBQ,
|
||||
wave_tile_size,
|
||||
wave_tile_count_x);
|
||||
}
|
||||
// if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
// {
|
||||
// printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size:
|
||||
// "
|
||||
// "%d, wave_tile_count_x: %d\n",
|
||||
// pad_bq_x,
|
||||
// WarpTileN,
|
||||
// NPerBlockBQ,
|
||||
// KPerBlockBQ,
|
||||
// wave_tile_size,
|
||||
// wave_tile_count_x);
|
||||
// }
|
||||
|
||||
const auto bq_unmerge_pad0_desc = transform_tensor_descriptor(
|
||||
bq_pad0_desc,
|
||||
@@ -383,13 +386,16 @@ struct QuantGemmKernel
|
||||
// where merged_outer_dim = bq_y * wave_tile_count_x
|
||||
// This layout facilitates efficient block-to-data mapping
|
||||
const auto pad_wave_size = ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("pad_wave_size:%d\n", pad_wave_size);
|
||||
}
|
||||
// if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
// {
|
||||
// printf("pad_wave_size:%d\n", pad_wave_size);
|
||||
// printf("Final bq tensor lengths: %d x %d \n",
|
||||
// bq_y * wave_tile_count_x,
|
||||
// pad_wave_size);
|
||||
// }
|
||||
const auto bq_merge_pad1_desc = transform_tensor_descriptor(
|
||||
bq_pad1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)), // 2
|
||||
make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)), // 4
|
||||
make_pass_through_transform(pad_wave_size)), // 64
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
@@ -1115,13 +1121,33 @@ struct QuantGemmKernel
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN;
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr auto block_n =
|
||||
TilePartitioner::NPerBlock / QuantGroupSize::kN; // 64 / 32 = 2
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
constexpr auto tile_window_width = ck_tile::integer_least_multiple(
|
||||
warp_n * bqk_per_block, get_warp_size()); // 128
|
||||
constexpr auto tile_window_height = block_n / warp_n; // 2
|
||||
auto block_n_idx = i_n / block_n;
|
||||
warp_n * bqk_per_block, get_warp_size()); // 128
|
||||
constexpr auto tile_window_height =
|
||||
min(block_n,
|
||||
TilePartitioner::BlockGemmShape::BlockWarps::at(
|
||||
I1)); // block_n / warp_n; // 2 / 4 = 0
|
||||
auto block_n_idx = i_n / TilePartitioner::NPerBlock; // 0,1,2
|
||||
|
||||
// if(get_thread_id() == 0)
|
||||
// {
|
||||
// printf("In MakeGemmTileWindows for BQ with PreshuffleQuant\n");
|
||||
// printf("block_id: %d, block_n: %d, warp_n: %d, bqk_per_block: %d,
|
||||
// block_n_idx: %d, "
|
||||
// "tile_window_width: %d, tile_window_height: %d, i_n: %d\n",
|
||||
// get_block_id(),
|
||||
// static_cast<int>(block_n),
|
||||
// static_cast<int>(warp_n),
|
||||
// static_cast<int>(bqk_per_block),
|
||||
// static_cast<int>(block_n_idx),
|
||||
// tile_window_width,
|
||||
// static_cast<int>(tile_window_height),
|
||||
// static_cast<int>(i_n));
|
||||
// }
|
||||
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
@@ -1226,15 +1252,15 @@ struct QuantGemmKernel
|
||||
{
|
||||
n = kargs.N;
|
||||
}
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("In RunGemm, before GemmPipeline call for BQuantGrouped\n");
|
||||
// To print Tile window after bq_pad0_desc
|
||||
// bq_block_window.template print_tile_window_range<BQDataType>(
|
||||
// 0, 128, 0, 2, "bq block window");
|
||||
bq_block_window.template print_tile_window_range<BQDataType>(
|
||||
0, 8, 0, 64, "bq block window");
|
||||
}
|
||||
// if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
// {
|
||||
// printf("In RunGemm, before GemmPipeline call for BQuantGrouped\n");
|
||||
// // To print Tile window after bq_pad0_desc
|
||||
// // bq_block_window.template print_tile_window_range<BQDataType>(
|
||||
// // 0, 128, 0, 2, "bq block window");
|
||||
// bq_block_window.template print_tile_window_range<BQDataType>(
|
||||
// 0, 8, 0, 64, "bq block window");
|
||||
// }
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n);
|
||||
}
|
||||
|
||||
@@ -196,7 +196,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t n,
|
||||
[[maybe_unused]] index_t n,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
@@ -280,9 +280,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
const BQDramTileWindowStep bq_dram_tile_window_step =
|
||||
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{}),
|
||||
0)
|
||||
(PreshuffleQuant)
|
||||
? make_array(((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, NPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0)
|
||||
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
|
||||
: make_array(0, KPerBlockBQ);
|
||||
|
||||
|
||||
@@ -245,56 +245,70 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
if constexpr(NPerQ <= WarpGemm::kN)
|
||||
{
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // 2
|
||||
constexpr auto N0 = WarpGemm::kN / NPerQ; //BlockGemmShape::kN / KPerQ; // 1
|
||||
constexpr auto N0 = WarpGemm::kN / NPerQ; // BlockGemmShape::kN / KPerQ; // 1
|
||||
constexpr auto N2 = 1;
|
||||
constexpr auto NR1 = NPerQ; // 16
|
||||
constexpr auto NR1 = NPerQ; // 16
|
||||
constexpr auto NR0 = warp_size / (N0 * N1 * N2 * NR1); // 64/(1*2*1*16)=2
|
||||
constexpr auto K1 = NWarps; // 4
|
||||
constexpr auto K0 = KPerTile / K1; // 1
|
||||
constexpr auto K1 = NWarps; // 4
|
||||
constexpr auto K0 = KPerTile / K1; // 1
|
||||
constexpr auto KR = 1;
|
||||
|
||||
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<MWarps, NR0, NR1, KR>,
|
||||
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2, 0, 2, 0>>, // (Mwarp, Nwarp),(XR0, X0, XR1, X1, YR)
|
||||
tuple<sequence<0, 1>, sequence<0, 2, 0, 2, 0>>, // (Mwarp, Nwarp),(XR0, X0,
|
||||
// XR1, X1, YR)
|
||||
tuple<sequence<0, 1>, sequence<1, 0, 2, 1, 3>>, // (1, 4), (2, 1, 16, 2, 1)
|
||||
sequence<1, 2>,
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
else if constexpr(NPerQ <= WarpGemm::kN * NWarps)
|
||||
else if constexpr(NPerQ < WarpGemm::kN * NWarps)
|
||||
{
|
||||
constexpr auto KR = NPerQ / WarpGemm::kN; // Scale replication factor 32/16 = 2
|
||||
constexpr auto K1 = NWarps / KR; // Warps per unique scale 4/2 = 2
|
||||
constexpr auto K0 = KPerTile / K1; // Iterations to cover N dimension 4/2 = 2
|
||||
constexpr auto KR = NPerQ / WarpGemm::kN; // Scale replication factor 64/16 = 4
|
||||
constexpr auto K1 = NWarps / KR; // Warps per unique scale 4/4 = 1
|
||||
constexpr auto K0 = KPerTile / K1; // Iterations to cover N dimension 4/1 = 4
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // 2
|
||||
constexpr auto N0 = 1; //NPerQ/WarpGemm::kN; // 2
|
||||
constexpr auto N0 = 1; // NPerQ/WarpGemm::kN; // 1
|
||||
constexpr auto N2 = 1;
|
||||
constexpr auto NR1 = NPerQ; // 32
|
||||
constexpr auto NR1 = NPerQ; // 32
|
||||
constexpr auto NR0 = warp_size / (N0 * N1 * N2 * NR1); // 64/(1*2*1*32)=1
|
||||
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
// Debug print to verify values
|
||||
printf("PreshuffleQuant Medium-grained: MWarps: %d, K1=%d, KR=%d, get_warp_size(): %d, K0=%d, N0=%d\n",
|
||||
MWarps,
|
||||
K1,
|
||||
KR,
|
||||
get_warp_size(),
|
||||
K0,
|
||||
N0);
|
||||
}
|
||||
|
||||
// if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
// {
|
||||
// // Debug print to verify values
|
||||
// printf("PreshuffleQuant Medium-grained: MWarps: %d, K1=%d, KR=%d,
|
||||
// get_warp_size(): %d, K0=%d, N0=%d\n",
|
||||
// MWarps,
|
||||
// K1,
|
||||
// KR,
|
||||
// get_warp_size(),
|
||||
// K0,
|
||||
// N0);
|
||||
// }
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
|
||||
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
|
||||
tuple<sequence<0, 1, 0>, sequence<0, 2, 0, 2>>,
|
||||
tuple<sequence<0, 1, 3>, sequence<1, 0, 2, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 2>>{});
|
||||
|
||||
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
|
||||
tuple<sequence<0, 1, 0>, sequence<0, 2, 0, 2>>,
|
||||
tuple<sequence<0, 1, 3>, sequence<1, 0, 2, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // 2
|
||||
constexpr auto N0 = 1; // NPerQ/WarpGemm::kN; // 1
|
||||
constexpr auto N2 = 1;
|
||||
constexpr auto NR1 = 32; // 32
|
||||
constexpr auto NR0 = warp_size / (N0 * N1 * N2 * NR1); // 64/(1*2*1*32)=1
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NWarps, NR0, NR1>,
|
||||
tuple<sequence<KPerTile>, sequence<N0, N1, N2>>,
|
||||
tuple<sequence<0, 0>, sequence<0, 2, 0, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<2, 0, 3, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user