debugging permuteN

This commit is contained in:
khuagarw
2025-11-18 21:59:30 +00:00
parent f5856af85e
commit 2275548400
10 changed files with 263 additions and 130 deletions

View File

@@ -11,9 +11,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
gemm_quant.cpp
gemm_aquant_quantgrouped.cpp
gemm_aquant_quantgrouped_preshufflequant.cpp
gemm_bquant_quantgrouped_bf8i4.cpp
gemm_bquant_quantgrouped_fp8i4.cpp
gemm_bquant_quantgrouped_bf8.cpp
# gemm_bquant_quantgrouped_bf8i4.cpp
# gemm_bquant_quantgrouped_fp8i4.cpp
# gemm_bquant_quantgrouped_bf8.cpp
gemm_bquant_quantgrouped_fp8.cpp
gemm_bquant_quantgrouped_preshuffleb.cpp
gemm_bquant_quantgrouped_preshufflequant.cpp

View File

@@ -9,51 +9,94 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
void bquant_quantgrouped_preshuffleb_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t,
float>{});
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t,
float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t,
float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t,
float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"bquant",
"preshuffleb",
"non-preshufflequant",
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"bquant",
"preshuffleb",
"non-preshufflequant",
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
// lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant",
// "1x1x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
// ck_tile::bf8_t,
// ck_tile::half_t,
// float>{});
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
// lut[hash_multiple_strings(
// {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
// ck_tile::pk_int4_t,
// ck_tile::half_t,
// ck_tile::fp8_t>{});
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
// lut[hash_multiple_strings(
// {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
// ck_tile::pk_int4_t,
// ck_tile::half_t,
// ck_tile::bf8_t>{});
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
}

View File

@@ -21,37 +21,39 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t,
float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
// lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
// ck_tile::bf8_t,
// ck_tile::half_t,
// float>{});
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
// lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})]
// =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
// ck_tile::pk_int4_t,
// ck_tile::half_t,
// ck_tile::fp8_t>{});
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
// lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})]
// =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
// ck_tile::pk_int4_t,
// ck_tile::half_t,
// ck_tile::bf8_t>{});
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
}

View File

@@ -21,39 +21,40 @@ void bquant_quantgrouped_preshufflequant_instance_factory(
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t,
float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
// lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant",
// "1x1x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
// ck_tile::bf8_t,
// ck_tile::half_t,
// float>{});
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
// lut[hash_multiple_strings(
// {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
// ck_tile::pk_int4_t,
// ck_tile::half_t,
// ck_tile::fp8_t>{});
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
// lut[hash_multiple_strings(
// {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
// ck_tile::pk_int4_t,
// ck_tile::half_t,
// ck_tile::bf8_t>{});
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
}

View File

@@ -17,9 +17,9 @@ auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("h", "false", "Print help message")
.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("m", "128", "m dimension")
.insert("n", "128", "n dimension")
.insert("k", "128", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row or Column")
.insert("b_layout", "C", "B tensor data layout - Row or Column")
.insert("bq_layout", "C", "Bq tensor data layout - Row or Column")
@@ -33,14 +33,14 @@ auto create_args(int argc, char* argv[])
"fp8",
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
"or bf8i4")
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
.insert("warmup", "1", "Number of iterations before benchmarking the kernel")
.insert("repeat", "0", "Number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "SplitK value")
.insert("device", "0", "Device id that will be used to run the kernel")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("flush_cache", "true", "Flush cache before running the kernel")
.insert("rotating_count", "1000", "Rotating count")
.insert("rotating_count", "0", "Rotating count")
.insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol")
.insert("preshuffleb", "false", "Enable preshuffle of tensor B")
.insert("preshufflequant", "false", "Enable preshuffle of quant tensor")
@@ -91,12 +91,12 @@ void aquant_quantgrouped_preshufflequant_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
// void bquant_quantgrouped_bf8_instance_factory(
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
// void bquant_quantgrouped_fp8i4_instance_factory(
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
// void bquant_quantgrouped_bf8i4_instance_factory(
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshuffleb_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshufflequant_instance_factory(
@@ -125,9 +125,9 @@ int main(int argc, char* argv[])
aquant_quantgrouped_instance_factory(lut);
aquant_quantgrouped_preshufflequant_instance_factory(lut);
bquant_quantgrouped_fp8_instance_factory(lut);
bquant_quantgrouped_bf8_instance_factory(lut);
bquant_quantgrouped_fp8i4_instance_factory(lut);
bquant_quantgrouped_bf8i4_instance_factory(lut);
// bquant_quantgrouped_bf8_instance_factory(lut);
// bquant_quantgrouped_fp8i4_instance_factory(lut);
// bquant_quantgrouped_bf8i4_instance_factory(lut);
bquant_quantgrouped_preshuffleb_instance_factory(lut);
bquant_quantgrouped_preshufflequant_instance_factory(lut);
bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut);

View File

@@ -210,7 +210,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
static constexpr bool DoubleSmemBuffer = true;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
static constexpr bool TiledMMAPermuteN = false; // N_Repeat % 2 == 0;
};
template <typename PrecType>

View File

@@ -481,7 +481,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
}
else
{
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<BDataType>{-2.0f,
3.0f /*, fill_seed(gen)*/}(b_k_n);
}
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
@@ -543,7 +544,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
b_k_n.SetZero();
bq_tensor_ptr->SetZero();
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
@@ -600,6 +600,14 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
{
printf("PreshuffleB with TiledMMAPermuteN\n");
b_k_n_dev = ck_tile::shuffle_b_permuteN<GemmConfig>(b_k_n);
printf("b_k_n_dev.get_lengths(): %lu, %lu, %lu, %lu, %lu, %lu, %lu\n",
b_k_n_dev.get_lengths()[0],
b_k_n_dev.get_lengths()[1],
b_k_n_dev.get_lengths()[2],
b_k_n_dev.get_lengths()[3],
b_k_n_dev.get_lengths()[4],
b_k_n_dev.get_lengths()[5],
b_k_n_dev.get_lengths()[6]);
}
else
{
@@ -624,8 +632,44 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN)
{
printf("Preshuffle BQ with TiledMMAPermuteN \n");
for(int i = 0; i < static_cast<int>((*bq_tensor_ptr).get_lengths()[0]); i++)
{
for(int j = 0; j < static_cast<int>((*bq_tensor_ptr).get_lengths()[1]); j++)
{
printf("(*bq_tensor_ptr)[%d][%d]: %f\n", i, j, (*bq_tensor_ptr)(i, j));
}
}
ck_tile::HostTensor<BQDataType> bq_permuted_host =
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr);
printf("bq_permuted_host.get_lengths(): %lu, %lu, %lu, %lu, %lu\n",
bq_permuted_host.get_lengths()[0],
bq_permuted_host.get_lengths()[1],
bq_permuted_host.get_lengths()[2],
bq_permuted_host.get_lengths()[3],
bq_permuted_host.get_lengths()[4]);
for(int i = 0; i < static_cast<int>(bq_permuted_host.get_lengths()[0]); i++)
{
for(int j = 0; j < static_cast<int>(bq_permuted_host.get_lengths()[1]); j++)
{
for(int k = 0; k < static_cast<int>(bq_permuted_host.get_lengths()[2]); k++)
{
for(int l = 0; l < static_cast<int>(bq_permuted_host.get_lengths()[3]); l++)
{
for(int m = 0; m < static_cast<int>(bq_permuted_host.get_lengths()[4]);
m++)
{
printf("bq_permuted_host[%d][%d][%d][%d][%d]: %f\n",
i,
j,
k,
l,
m,
bq_permuted_host(i, j, k, l, m));
}
}
}
}
}
if constexpr(GemmConfig::PreshuffleQuant)
{

View File

@@ -1197,7 +1197,7 @@ struct tile_window_with_static_lengths
using ThreadBuf = thread_buffer<DataType, 2>;
auto buf = tensor_view.template get_vectorized_elements<ThreadBuf>(coord, 0);
auto value = buf.at(number<0>{}); // Extract first element from thread buffer
printf(" %s[%d,%d] = %f", label, i, j, type_convert<float>(value));
printf(" %s[%d,%d] = %f\n", label, i, j, type_convert<float>(value));
}
printf("\n");
}

View File

@@ -28,7 +28,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
// static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
@@ -204,14 +204,40 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
}
else
{
constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale;
index_t reg_offset = [&]() {
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
kQScale;
}
else
{
return nIter * KPerBlockBQ + kQScale;
}
}();
// constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = cvt_scale_to_fp32(scale_reg);
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("scale_reg_f: %f, reg_offset: %d\n", scale_reg_f, reg_offset);
printf("nIter: %d, NWarp: %d, WG::kN: %d, QuantGroupSize::kN: %d, "
"KPerBlockBQ: %d, kQScale: %d\n",
static_cast<int>(nIter),
NWarp,
WG::kN,
static_cast<int>(QuantGroupSize::kN),
static_cast<int>(KPerBlockBQ),
static_cast<int>(kQScale));
}
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * scale_reg_f;
// if(get_block_id() == 0 && get_thread_id() == 0) {
// printf("acc_val: %f, scale_reg_f: %f\n", acc_val, scale_reg_f);
// }
c_ref = c_ref + acc_val * scale_reg_f;
});
}
});

View File

@@ -654,7 +654,10 @@ struct QuantGemmKernel
(splitk_batch_offset.splitted_k /
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("kFlatN: %d, kFlatK: %d\n", kFlatN, kFlatK);
}
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kFlatN, kFlatK),
@@ -989,10 +992,19 @@ struct QuantGemmKernel
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("TilePartitioner::KPerBlock %d, QuantGroupSize::kK: %d, "
"TilePartitioner::NPerBlock %d, QuantGroupSize::kN: %d\n",
TilePartitioner::KPerBlock,
QuantGroupSize::kK,
TilePartitioner::NPerBlock,
QuantGroupSize::kN);
}
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}, // 1
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}), // 16
{0, i_n / QuantGroupSize::kN});
}
}
@@ -1152,6 +1164,11 @@ struct QuantGemmKernel
if constexpr(kQuantType == QuantType::BQuantGrouped)
{
const auto& bq_block_window = gemm_tile_windows.at(I3);
if(get_block_id() == 0 && get_thread_id() == 0)
{
bq_block_window.template print_tile_window_range<BQDataType>(
0, 1, 0, 32, "bq block window");
}
return GemmPipeline{}.template operator()(a_block_window,
b_block_window,
bq_block_window,