initial commit

This commit is contained in:
khuagarw
2025-11-25 23:38:41 +00:00
parent 04aaf97192
commit 7788979fbe
12 changed files with 279 additions and 319 deletions

View File

@@ -11,9 +11,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
gemm_quant.cpp
gemm_aquant_quantgrouped.cpp
gemm_aquant_quantgrouped_preshufflequant.cpp
# gemm_bquant_quantgrouped_bf8i4.cpp
# gemm_bquant_quantgrouped_fp8i4.cpp
# gemm_bquant_quantgrouped_bf8.cpp
gemm_bquant_quantgrouped_bf8i4.cpp
gemm_bquant_quantgrouped_fp8i4.cpp
gemm_bquant_quantgrouped_bf8.cpp
gemm_bquant_quantgrouped_fp8.cpp
gemm_bquant_quantgrouped_preshuffleb.cpp
gemm_bquant_quantgrouped_preshufflequant.cpp

View File

@@ -33,19 +33,6 @@ void bquant_quantgrouped_preshuffleb_instance_factory(
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"bquant",
"preshuffleb",
"non-preshufflequant",
"1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"bquant",
"preshuffleb",
@@ -73,43 +60,158 @@ void bquant_quantgrouped_preshuffleb_instance_factory(
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
// lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant",
// "1x1x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
// ck_tile::bf8_t,
// ck_tile::half_t,
// float>{});
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
// lut[hash_multiple_strings(
// {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
// ck_tile::pk_int4_t,
// ck_tile::half_t,
// ck_tile::fp8_t>{});
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
// lut[hash_multiple_strings(
// {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
// ck_tile::pk_int4_t,
// ck_tile::half_t,
// ck_tile::bf8_t>{});
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t,
float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t,
float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8",
"bquant",
"preshuffleb",
"non-preshufflequant",
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8",
"bquant",
"preshuffleb",
"non-preshufflequant",
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
}

View File

@@ -21,39 +21,37 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
// lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
// ck_tile::bf8_t,
// ck_tile::half_t,
// float>{});
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
// lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})]
// =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
// ck_tile::pk_int4_t,
// ck_tile::half_t,
// ck_tile::fp8_t>{});
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
// lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})]
// =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
// ck_tile::pk_int4_t,
// ck_tile::half_t,
// ck_tile::bf8_t>{});
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t,
float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
}

View File

@@ -21,40 +21,39 @@ void bquant_quantgrouped_preshufflequant_instance_factory(
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
// lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant",
// "1x1x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
// ck_tile::bf8_t,
// ck_tile::half_t,
// float>{});
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
// lut[hash_multiple_strings(
// {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
// ck_tile::pk_int4_t,
// ck_tile::half_t,
// ck_tile::fp8_t>{});
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
// lut[hash_multiple_strings(
// {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
// ck_tile::pk_int4_t,
// ck_tile::half_t,
// ck_tile::bf8_t>{});
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t,
float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
}

View File

@@ -17,9 +17,9 @@ auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("h", "false", "Print help message")
.insert("m", "128", "m dimension")
.insert("n", "128", "n dimension")
.insert("k", "128", "k dimension")
.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row or Column")
.insert("b_layout", "C", "B tensor data layout - Row or Column")
.insert("bq_layout", "C", "Bq tensor data layout - Row or Column")
@@ -33,14 +33,14 @@ auto create_args(int argc, char* argv[])
"fp8",
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
"or bf8i4")
.insert("warmup", "1", "Number of iterations before benchmarking the kernel")
.insert("repeat", "0", "Number of iterations to benchmark the kernel")
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "SplitK value")
.insert("device", "0", "Device id that will be used to run the kernel")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("flush_cache", "true", "Flush cache before running the kernel")
.insert("rotating_count", "0", "Rotating count")
.insert("rotating_count", "1000", "Rotating count")
.insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol")
.insert("preshuffleb", "false", "Enable preshuffle of tensor B")
.insert("preshufflequant", "false", "Enable preshuffle of quant tensor")
@@ -91,12 +91,12 @@ void aquant_quantgrouped_preshufflequant_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
// void bquant_quantgrouped_bf8_instance_factory(
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
// void bquant_quantgrouped_fp8i4_instance_factory(
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
// void bquant_quantgrouped_bf8i4_instance_factory(
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshuffleb_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshufflequant_instance_factory(
@@ -125,9 +125,9 @@ int main(int argc, char* argv[])
aquant_quantgrouped_instance_factory(lut);
aquant_quantgrouped_preshufflequant_instance_factory(lut);
bquant_quantgrouped_fp8_instance_factory(lut);
// bquant_quantgrouped_bf8_instance_factory(lut);
// bquant_quantgrouped_fp8i4_instance_factory(lut);
// bquant_quantgrouped_bf8i4_instance_factory(lut);
bquant_quantgrouped_bf8_instance_factory(lut);
bquant_quantgrouped_fp8i4_instance_factory(lut);
bquant_quantgrouped_bf8i4_instance_factory(lut);
bquant_quantgrouped_preshuffleb_instance_factory(lut);
bquant_quantgrouped_preshufflequant_instance_factory(lut);
bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut);

View File

@@ -457,9 +457,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
{
bq_tensor_ptr =
std::make_unique<ck_tile::HostTensor<BQDataType>>(ck_tile::host_tensor_descriptor(
BQK, N / QuantGroupSize::kN, stride_BQ, is_row_major(bq_layout))); // 1x8
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout)));
}
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
{
@@ -482,12 +481,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
}
else
{
ck_tile::FillUniformDistribution<BDataType>{-2.0f,
3.0f /*, fill_seed(gen)*/}(b_k_n);
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
}
ck_tile::FillUniformDistribution<BQDataType>{-2.0f,
2.0f /*, fill_seed(gen)*/}(*bq_tensor_ptr);
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f /*, fill_seed(gen)*/}(a_m_k);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
}
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
@@ -524,30 +522,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
// ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
if(bq_tensor_ptr)
{
BQDataType value = 1.0f;
for(int i = 0; i < BQK; i++)
{
for(int j = 0; j < N / QuantGroupSize::kN; j += (16 / QuantGroupSize::kN))
{
for(int k = 0; k < 16 / QuantGroupSize::kN; k++)
{
(*bq_tensor_ptr)(i, j + k) = value;
}
value += static_cast<BQDataType>(1.0f);
}
}
}
// for(int i = 0; i < BQK; i++)
// {
// for(int j = 0; j < N / QuantGroupSize::kN; j++)
// {
// printf("%.2f ", (*bq_tensor_ptr)(i, j));
// }
// printf("\n");
// }
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
}
else
{
@@ -620,18 +595,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
if constexpr(GemmConfig::PreshuffleB)
{
if constexpr(GemmConfig::TiledMMAPermuteN)
if constexpr(GemmConfig::TiledMMAPermuteN &&
QuantGroupSize::kN == 1) // temporarily only for non-grouped quant
{
printf("PreshuffleB with TiledMMAPermuteN\n");
b_k_n_dev = ck_tile::shuffle_b_permuteN<GemmConfig>(b_k_n);
printf("b_k_n_dev.get_lengths(): %lu, %lu, %lu, %lu, %lu, %lu, %lu\n",
b_k_n_dev.get_lengths()[0],
b_k_n_dev.get_lengths()[1],
b_k_n_dev.get_lengths()[2],
b_k_n_dev.get_lengths()[3],
b_k_n_dev.get_lengths()[4],
b_k_n_dev.get_lengths()[5],
b_k_n_dev.get_lengths()[6]);
}
else
{
@@ -653,47 +621,12 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN)
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN &&
QuantGroupSize::kN == 1) // temporarily only for non-grouped quant
{
printf("Preshuffle BQ with TiledMMAPermuteN \n");
for(int i = 0; i < static_cast<int>((*bq_tensor_ptr).get_lengths()[0]); i++)
{
for(int j = 0; j < static_cast<int>((*bq_tensor_ptr).get_lengths()[1]); j++)
{
printf("(*bq_tensor_ptr)[%d][%d]: %f\n", i, j, (*bq_tensor_ptr)(i, j));
}
}
printf("PreshuffleBQuant with TiledMMAPermuteN\n");
ck_tile::HostTensor<BQDataType> bq_permuted_host =
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr, QuantGroupSize::kN);
printf("bq_permuted_host.get_lengths(): %lu, %lu, %lu, %lu, %lu\n",
bq_permuted_host.get_lengths()[0],
bq_permuted_host.get_lengths()[1],
bq_permuted_host.get_lengths()[2],
bq_permuted_host.get_lengths()[3],
bq_permuted_host.get_lengths()[4]);
for(int i = 0; i < static_cast<int>(bq_permuted_host.get_lengths()[0]); i++)
{
for(int j = 0; j < static_cast<int>(bq_permuted_host.get_lengths()[1]); j++)
{
for(int k = 0; k < static_cast<int>(bq_permuted_host.get_lengths()[2]); k++)
{
for(int l = 0; l < static_cast<int>(bq_permuted_host.get_lengths()[3]); l++)
{
for(int m = 0; m < static_cast<int>(bq_permuted_host.get_lengths()[4]);
m++)
{
printf("bq_permuted_host[%d][%d][%d][%d][%d]: %f\n",
i,
j,
k,
l,
m,
bq_permuted_host(i, j, k, l, m));
}
}
}
}
}
if constexpr(GemmConfig::PreshuffleQuant)
{
@@ -708,12 +641,16 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
}
else if constexpr(GemmConfig::PreshuffleQuant)
{
printf("PreshuffleBQuant without TiledMMAPermuteN\n");
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize::kK);
bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data());
}
else
{
printf("No PreshuffleBQuant\n");
bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data());
}
}
invoke_gemm<GemmConfig,

View File

@@ -1251,7 +1251,7 @@ struct tile_window_with_static_lengths
using ThreadBuf = thread_buffer<DataType, 2>;
auto buf = tensor_view.template get_vectorized_elements<ThreadBuf>(coord, 0);
auto value = buf.at(number<0>{}); // Extract first element from thread buffer
printf(" %s[%d,%d] = %f\n", label, i, j, type_convert<float>(value));
printf(" %s[%d,%d] = %f", label, i, j, type_convert<float>(value));
}
printf("\n");
}

View File

@@ -111,49 +111,17 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1]; // 128
int bqk_ = t.get_lengths()[0]; // 1 x 128
constexpr int NRepeat =
GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; // 128/16/4 = 2
int n_ = t.get_lengths()[1];
int bqk_ = t.get_lengths()[0];
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
ck_tile::HostTensor<T> t_view({n_ / (GemmConfig::N_Tile / group_n),
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile / group_n,
NRepeat,
bqk_}); //{1, 4, 16, 2, 1}, group_n:16 {1, 4, 1, 2, 1}
bqk_});
std::copy(t.begin(), t.end(), t_view.begin());
printf("I am inside bq_permuteN\n");
printf("t.get_lengths(): %lu, %lu, %lu, %lu, %lu\n",
t_view.get_lengths()[0],
t_view.get_lengths()[1],
t_view.get_lengths()[2],
t_view.get_lengths()[3],
t_view.get_lengths()[4]);
for(int i = 0; i < static_cast<int>(t.get_lengths()[0]); i++)
{
for(int j = 0; j < static_cast<int>(t_view.get_lengths()[1]); j++)
{
for(int k = 0; k < static_cast<int>(t_view.get_lengths()[2]); k++)
{
for(int l = 0; l < static_cast<int>(t_view.get_lengths()[3]); l++)
{
for(int m = 0; m < static_cast<int>(t_view.get_lengths()[4]); m++)
{
printf("t_view[%d][%d][%d][%d][%d]: %f\n",
i,
j,
k,
l,
m,
t_view(i, j, k, l, m));
}
}
}
}
}
printf("I am inside bq_permuteN\n");
return ck_tile::reference_permute(
t_view, {0, 3, 1, 2, 4}); // {1, 2, 4, 16, 1}, group_n 16 {1, 2, 4, 1, 1}
return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
}
template <typename GemmConfig, typename T>

View File

@@ -28,7 +28,6 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
// static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
@@ -215,35 +214,13 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
return nIter * KPerBlockBQ + kQScale;
}
}();
// constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = cvt_scale_to_fp32(scale_reg);
// if(get_block_id() == 0 && get_thread_id() == 1)
//{
// printf("get_block_id(): %d, get_warp_id(): %d, get_thread_id(): %d,
// nIter: "
// "%d, NWarp: %d, WG::kN: %d, QuantGroupSize::kN: %d, "
// "KPerBlockBQ: %d, kQScale: %d, scale_reg_f: %f, reg_offset: %d\n",
// get_block_id(),
// get_warp_id(),
// get_thread_id(),
// static_cast<int>(nIter),
// NWarp,
// WG::kN,
// static_cast<int>(QuantGroupSize::kN),
// static_cast<int>(KPerBlockBQ),
// static_cast<int>(kQScale),
// scale_reg_f,
// reg_offset);
//}
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
// if(get_block_id() == 0 && get_thread_id() == 0) {
// printf("acc_val: %f, scale_reg_f: %f\n", acc_val, scale_reg_f);
// }
c_ref = c_ref + acc_val * scale_reg_f;
c_ref = c_ref + acc_val * scale_reg_f;
});
}
});

View File

@@ -653,10 +653,6 @@ struct QuantGemmKernel
(splitk_batch_offset.splitted_k /
GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("kFlatN: %d, kFlatK: %d\n", kFlatN, kFlatK);
}
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kFlatN, kFlatK),
@@ -991,20 +987,10 @@ struct QuantGemmKernel
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("TilePartitioner::KPerBlock %d, QuantGroupSize::kK: %d, "
"TilePartitioner::NPerBlock %d, QuantGroupSize::kN: %d\n",
TilePartitioner::KPerBlock,
QuantGroupSize::kK,
TilePartitioner::NPerBlock,
QuantGroupSize::kN);
}
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}, // 1
number<TilePartitioner::NPerBlock /
QuantGroupSize::kN>{}), // 128/16 = 8
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
{0, i_n / QuantGroupSize::kN});
}
}
@@ -1164,11 +1150,6 @@ struct QuantGemmKernel
if constexpr(kQuantType == QuantType::BQuantGrouped)
{
const auto& bq_block_window = gemm_tile_windows.at(I3);
if(get_block_id() == 0 && get_thread_id() == 0)
{
bq_block_window.template print_tile_window_range<BQDataType>(
0, 1, 0, 128, "bq block window");
}
return GemmPipeline{}.template operator()(a_block_window,
b_block_window,
bq_block_window,

View File

@@ -71,8 +71,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
tile_distribution_encoding_pattern_bq<BlockGemmShape,
WarpGemm,
BlockSize,
KPerBlockBQ, // 128/128 = 1
NPerBlockBQ, // 128/16 = 8
KPerBlockBQ,
NPerBlockBQ,
Problem::QuantGroupSize::kN>;
return TileEncodingPattern::make_2d_static_tile_distribution();

View File

@@ -169,9 +169,9 @@ struct tile_distribution_encoding_pattern_aq_transposed_c
template <typename BlockGemmShape,
typename WarpGemm,
index_t BlockSize,
index_t YPerTile, // 1
index_t XPerTile, // 8
index_t XPerQ, // 16
index_t YPerTile,
index_t XPerTile,
index_t XPerQ,
bool PreshuffleQuant = false>
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
{
@@ -255,18 +255,16 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
else if constexpr(XPerQ <= WarpGemm::kN * NWarps)
{
// Case 2: Medium-grained - one quantization scale per warp
constexpr auto XR =
XPerQ / WarpGemm::kN; // Scale replication factor //16/16 = 1
constexpr auto X1 = NWarps / XR; // Warps per unique scale //4/1 = 4
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension //8/4 = 2
constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
constexpr auto X1 = NWarps / XR; // Warps per unique scale
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<MWarps, XR, get_warp_size()>, // 1, 1, 64
tuple<sequence<YPerTile>, sequence<X0, X1>>, // 1, (2, 4)
tuple<sequence<0, 2, 0>, sequence<0>>, //(1, 4, 1) (64)
tuple<sequence<0, 1, 1>, sequence<2>>,
sequence<2, 1>, //(2, 1(in Y dimension))
sequence<0, 0>>{});
tile_distribution_encoding<sequence<MWarps, XR, get_warp_size()>,
tuple<sequence<YPerTile>, sequence<X0, X1>>,
tuple<sequence<0, 2, 0>, sequence<0>>,
tuple<sequence<0, 1, 1>, sequence<2>>,
sequence<2, 1>,
sequence<0, 0>>{});
}
else // XPerQ > WarpGemm::kN * NWarps
{