mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
initial commit
This commit is contained in:
@@ -11,9 +11,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
gemm_quant.cpp
|
||||
gemm_aquant_quantgrouped.cpp
|
||||
gemm_aquant_quantgrouped_preshufflequant.cpp
|
||||
# gemm_bquant_quantgrouped_bf8i4.cpp
|
||||
# gemm_bquant_quantgrouped_fp8i4.cpp
|
||||
# gemm_bquant_quantgrouped_bf8.cpp
|
||||
gemm_bquant_quantgrouped_bf8i4.cpp
|
||||
gemm_bquant_quantgrouped_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_bf8.cpp
|
||||
gemm_bquant_quantgrouped_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant.cpp
|
||||
|
||||
@@ -33,19 +33,6 @@ void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
@@ -73,43 +60,158 @@ void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
|
||||
// lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant",
|
||||
// "1x1x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
// ck_tile::bf8_t,
|
||||
// ck_tile::half_t,
|
||||
// float>{});
|
||||
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
// lut[hash_multiple_strings(
|
||||
// {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
// ck_tile::pk_int4_t,
|
||||
// ck_tile::half_t,
|
||||
// ck_tile::fp8_t>{});
|
||||
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
// lut[hash_multiple_strings(
|
||||
// {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
// ck_tile::pk_int4_t,
|
||||
// ck_tile::half_t,
|
||||
// ck_tile::bf8_t>{});
|
||||
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -21,39 +21,37 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
// lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
// ck_tile::bf8_t,
|
||||
// ck_tile::half_t,
|
||||
// float>{});
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
// lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})]
|
||||
// =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
// ck_tile::pk_int4_t,
|
||||
// ck_tile::half_t,
|
||||
// ck_tile::fp8_t>{});
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
// lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})]
|
||||
// =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
// ck_tile::pk_int4_t,
|
||||
// ck_tile::half_t,
|
||||
// ck_tile::bf8_t>{});
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -21,40 +21,39 @@ void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
// lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant",
|
||||
// "1x1x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
// ck_tile::bf8_t,
|
||||
// ck_tile::half_t,
|
||||
// float>{});
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
// lut[hash_multiple_strings(
|
||||
// {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
// ck_tile::pk_int4_t,
|
||||
// ck_tile::half_t,
|
||||
// ck_tile::fp8_t>{});
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
// lut[hash_multiple_strings(
|
||||
// {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
// ck_tile::pk_int4_t,
|
||||
// ck_tile::half_t,
|
||||
// ck_tile::bf8_t>{});
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -17,9 +17,9 @@ auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("h", "false", "Print help message")
|
||||
.insert("m", "128", "m dimension")
|
||||
.insert("n", "128", "n dimension")
|
||||
.insert("k", "128", "k dimension")
|
||||
.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row or Column")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row or Column")
|
||||
.insert("bq_layout", "C", "Bq tensor data layout - Row or Column")
|
||||
@@ -33,14 +33,14 @@ auto create_args(int argc, char* argv[])
|
||||
"fp8",
|
||||
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
|
||||
"or bf8i4")
|
||||
.insert("warmup", "1", "Number of iterations before benchmarking the kernel")
|
||||
.insert("repeat", "0", "Number of iterations to benchmark the kernel")
|
||||
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
|
||||
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "SplitK value")
|
||||
.insert("device", "0", "Device id that will be used to run the kernel")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("flush_cache", "true", "Flush cache before running the kernel")
|
||||
.insert("rotating_count", "0", "Rotating count")
|
||||
.insert("rotating_count", "1000", "Rotating count")
|
||||
.insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol")
|
||||
.insert("preshuffleb", "false", "Enable preshuffle of tensor B")
|
||||
.insert("preshufflequant", "false", "Enable preshuffle of quant tensor")
|
||||
@@ -91,12 +91,12 @@ void aquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void bquant_quantgrouped_bf8_instance_factory(
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
@@ -125,9 +125,9 @@ int main(int argc, char* argv[])
|
||||
aquant_quantgrouped_instance_factory(lut);
|
||||
aquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
bquant_quantgrouped_fp8_instance_factory(lut);
|
||||
// bquant_quantgrouped_bf8_instance_factory(lut);
|
||||
// bquant_quantgrouped_fp8i4_instance_factory(lut);
|
||||
// bquant_quantgrouped_bf8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_bf8_instance_factory(lut);
|
||||
bquant_quantgrouped_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_bf8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut);
|
||||
|
||||
@@ -457,9 +457,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
bq_tensor_ptr =
|
||||
std::make_unique<ck_tile::HostTensor<BQDataType>>(ck_tile::host_tensor_descriptor(
|
||||
BQK, N / QuantGroupSize::kN, stride_BQ, is_row_major(bq_layout))); // 1x8
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
@@ -482,12 +481,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f,
|
||||
3.0f /*, fill_seed(gen)*/}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f,
|
||||
2.0f /*, fill_seed(gen)*/}(*bq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f /*, fill_seed(gen)*/}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
@@ -524,30 +522,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
|
||||
// ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
if(bq_tensor_ptr)
|
||||
{
|
||||
BQDataType value = 1.0f;
|
||||
for(int i = 0; i < BQK; i++)
|
||||
{
|
||||
for(int j = 0; j < N / QuantGroupSize::kN; j += (16 / QuantGroupSize::kN))
|
||||
{
|
||||
for(int k = 0; k < 16 / QuantGroupSize::kN; k++)
|
||||
{
|
||||
(*bq_tensor_ptr)(i, j + k) = value;
|
||||
}
|
||||
value += static_cast<BQDataType>(1.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
// for(int i = 0; i < BQK; i++)
|
||||
// {
|
||||
// for(int j = 0; j < N / QuantGroupSize::kN; j++)
|
||||
// {
|
||||
// printf("%.2f ", (*bq_tensor_ptr)(i, j));
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -620,18 +595,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
if constexpr(GemmConfig::PreshuffleB)
|
||||
{
|
||||
if constexpr(GemmConfig::TiledMMAPermuteN)
|
||||
if constexpr(GemmConfig::TiledMMAPermuteN &&
|
||||
QuantGroupSize::kN == 1) // temporarily only for non-grouped quant
|
||||
{
|
||||
printf("PreshuffleB with TiledMMAPermuteN\n");
|
||||
b_k_n_dev = ck_tile::shuffle_b_permuteN<GemmConfig>(b_k_n);
|
||||
printf("b_k_n_dev.get_lengths(): %lu, %lu, %lu, %lu, %lu, %lu, %lu\n",
|
||||
b_k_n_dev.get_lengths()[0],
|
||||
b_k_n_dev.get_lengths()[1],
|
||||
b_k_n_dev.get_lengths()[2],
|
||||
b_k_n_dev.get_lengths()[3],
|
||||
b_k_n_dev.get_lengths()[4],
|
||||
b_k_n_dev.get_lengths()[5],
|
||||
b_k_n_dev.get_lengths()[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -653,47 +621,12 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN)
|
||||
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN &&
|
||||
QuantGroupSize::kN == 1) // temporarily only for non-grouped quant
|
||||
{
|
||||
printf("Preshuffle BQ with TiledMMAPermuteN \n");
|
||||
for(int i = 0; i < static_cast<int>((*bq_tensor_ptr).get_lengths()[0]); i++)
|
||||
{
|
||||
for(int j = 0; j < static_cast<int>((*bq_tensor_ptr).get_lengths()[1]); j++)
|
||||
{
|
||||
printf("(*bq_tensor_ptr)[%d][%d]: %f\n", i, j, (*bq_tensor_ptr)(i, j));
|
||||
}
|
||||
}
|
||||
printf("PreshuffleBQuant with TiledMMAPermuteN\n");
|
||||
ck_tile::HostTensor<BQDataType> bq_permuted_host =
|
||||
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr, QuantGroupSize::kN);
|
||||
printf("bq_permuted_host.get_lengths(): %lu, %lu, %lu, %lu, %lu\n",
|
||||
bq_permuted_host.get_lengths()[0],
|
||||
bq_permuted_host.get_lengths()[1],
|
||||
bq_permuted_host.get_lengths()[2],
|
||||
bq_permuted_host.get_lengths()[3],
|
||||
bq_permuted_host.get_lengths()[4]);
|
||||
for(int i = 0; i < static_cast<int>(bq_permuted_host.get_lengths()[0]); i++)
|
||||
{
|
||||
for(int j = 0; j < static_cast<int>(bq_permuted_host.get_lengths()[1]); j++)
|
||||
{
|
||||
for(int k = 0; k < static_cast<int>(bq_permuted_host.get_lengths()[2]); k++)
|
||||
{
|
||||
for(int l = 0; l < static_cast<int>(bq_permuted_host.get_lengths()[3]); l++)
|
||||
{
|
||||
for(int m = 0; m < static_cast<int>(bq_permuted_host.get_lengths()[4]);
|
||||
m++)
|
||||
{
|
||||
printf("bq_permuted_host[%d][%d][%d][%d][%d]: %f\n",
|
||||
i,
|
||||
j,
|
||||
k,
|
||||
l,
|
||||
m,
|
||||
bq_permuted_host(i, j, k, l, m));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
@@ -708,12 +641,16 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
else if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
printf("PreshuffleBQuant without TiledMMAPermuteN\n");
|
||||
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("No PreshuffleBQuant\n");
|
||||
bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data());
|
||||
}
|
||||
}
|
||||
|
||||
invoke_gemm<GemmConfig,
|
||||
|
||||
@@ -1251,7 +1251,7 @@ struct tile_window_with_static_lengths
|
||||
using ThreadBuf = thread_buffer<DataType, 2>;
|
||||
auto buf = tensor_view.template get_vectorized_elements<ThreadBuf>(coord, 0);
|
||||
auto value = buf.at(number<0>{}); // Extract first element from thread buffer
|
||||
printf(" %s[%d,%d] = %f\n", label, i, j, type_convert<float>(value));
|
||||
printf(" %s[%d,%d] = %f", label, i, j, type_convert<float>(value));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
@@ -111,49 +111,17 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
|
||||
int n_ = t.get_lengths()[1]; // 128
|
||||
int bqk_ = t.get_lengths()[0]; // 1 x 128
|
||||
constexpr int NRepeat =
|
||||
GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; // 128/16/4 = 2
|
||||
int n_ = t.get_lengths()[1];
|
||||
int bqk_ = t.get_lengths()[0];
|
||||
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
|
||||
|
||||
ck_tile::HostTensor<T> t_view({n_ / (GemmConfig::N_Tile / group_n),
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::N_Warp_Tile / group_n,
|
||||
NRepeat,
|
||||
bqk_}); //{1, 4, 16, 2, 1}, group_n:16 {1, 4, 1, 2, 1}
|
||||
bqk_});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
printf("I am inside bq_permuteN\n");
|
||||
printf("t.get_lengths(): %lu, %lu, %lu, %lu, %lu\n",
|
||||
t_view.get_lengths()[0],
|
||||
t_view.get_lengths()[1],
|
||||
t_view.get_lengths()[2],
|
||||
t_view.get_lengths()[3],
|
||||
t_view.get_lengths()[4]);
|
||||
for(int i = 0; i < static_cast<int>(t.get_lengths()[0]); i++)
|
||||
{
|
||||
for(int j = 0; j < static_cast<int>(t_view.get_lengths()[1]); j++)
|
||||
{
|
||||
for(int k = 0; k < static_cast<int>(t_view.get_lengths()[2]); k++)
|
||||
{
|
||||
for(int l = 0; l < static_cast<int>(t_view.get_lengths()[3]); l++)
|
||||
{
|
||||
for(int m = 0; m < static_cast<int>(t_view.get_lengths()[4]); m++)
|
||||
{
|
||||
printf("t_view[%d][%d][%d][%d][%d]: %f\n",
|
||||
i,
|
||||
j,
|
||||
k,
|
||||
l,
|
||||
m,
|
||||
t_view(i, j, k, l, m));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
printf("I am inside bq_permuteN\n");
|
||||
return ck_tile::reference_permute(
|
||||
t_view, {0, 3, 1, 2, 4}); // {1, 2, 4, 16, 1}, group_n 16 {1, 2, 4, 1, 1}
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
|
||||
@@ -28,7 +28,6 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
|
||||
// static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
@@ -215,35 +214,13 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
return nIter * KPerBlockBQ + kQScale;
|
||||
}
|
||||
}();
|
||||
// constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale;
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float scale_reg_f = cvt_scale_to_fp32(scale_reg);
|
||||
// if(get_block_id() == 0 && get_thread_id() == 1)
|
||||
//{
|
||||
// printf("get_block_id(): %d, get_warp_id(): %d, get_thread_id(): %d,
|
||||
// nIter: "
|
||||
// "%d, NWarp: %d, WG::kN: %d, QuantGroupSize::kN: %d, "
|
||||
// "KPerBlockBQ: %d, kQScale: %d, scale_reg_f: %f, reg_offset: %d\n",
|
||||
// get_block_id(),
|
||||
// get_warp_id(),
|
||||
// get_thread_id(),
|
||||
// static_cast<int>(nIter),
|
||||
// NWarp,
|
||||
// WG::kN,
|
||||
// static_cast<int>(QuantGroupSize::kN),
|
||||
// static_cast<int>(KPerBlockBQ),
|
||||
// static_cast<int>(kQScale),
|
||||
// scale_reg_f,
|
||||
// reg_offset);
|
||||
//}
|
||||
|
||||
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
|
||||
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
|
||||
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
|
||||
// if(get_block_id() == 0 && get_thread_id() == 0) {
|
||||
// printf("acc_val: %f, scale_reg_f: %f\n", acc_val, scale_reg_f);
|
||||
// }
|
||||
c_ref = c_ref + acc_val * scale_reg_f;
|
||||
c_ref = c_ref + acc_val * scale_reg_f;
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
@@ -653,10 +653,6 @@ struct QuantGemmKernel
|
||||
(splitk_batch_offset.splitted_k /
|
||||
GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("kFlatN: %d, kFlatK: %d\n", kFlatN, kFlatK);
|
||||
}
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
@@ -991,20 +987,10 @@ struct QuantGemmKernel
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("TilePartitioner::KPerBlock %d, QuantGroupSize::kK: %d, "
|
||||
"TilePartitioner::NPerBlock %d, QuantGroupSize::kN: %d\n",
|
||||
TilePartitioner::KPerBlock,
|
||||
QuantGroupSize::kK,
|
||||
TilePartitioner::NPerBlock,
|
||||
QuantGroupSize::kN);
|
||||
}
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}, // 1
|
||||
number<TilePartitioner::NPerBlock /
|
||||
QuantGroupSize::kN>{}), // 128/16 = 8
|
||||
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
|
||||
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
|
||||
{0, i_n / QuantGroupSize::kN});
|
||||
}
|
||||
}
|
||||
@@ -1164,11 +1150,6 @@ struct QuantGemmKernel
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I3);
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
bq_block_window.template print_tile_window_range<BQDataType>(
|
||||
0, 1, 0, 128, "bq block window");
|
||||
}
|
||||
return GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
|
||||
@@ -71,8 +71,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
tile_distribution_encoding_pattern_bq<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
KPerBlockBQ, // 128/128 = 1
|
||||
NPerBlockBQ, // 128/16 = 8
|
||||
KPerBlockBQ,
|
||||
NPerBlockBQ,
|
||||
Problem::QuantGroupSize::kN>;
|
||||
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
|
||||
@@ -169,9 +169,9 @@ struct tile_distribution_encoding_pattern_aq_transposed_c
|
||||
template <typename BlockGemmShape,
|
||||
typename WarpGemm,
|
||||
index_t BlockSize,
|
||||
index_t YPerTile, // 1
|
||||
index_t XPerTile, // 8
|
||||
index_t XPerQ, // 16
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t XPerQ,
|
||||
bool PreshuffleQuant = false>
|
||||
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
|
||||
{
|
||||
@@ -255,18 +255,16 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
else if constexpr(XPerQ <= WarpGemm::kN * NWarps)
|
||||
{
|
||||
// Case 2: Medium-grained - one quantization scale per warp
|
||||
constexpr auto XR =
|
||||
XPerQ / WarpGemm::kN; // Scale replication factor //16/16 = 1
|
||||
constexpr auto X1 = NWarps / XR; // Warps per unique scale //4/1 = 4
|
||||
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension //8/4 = 2
|
||||
constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
|
||||
constexpr auto X1 = NWarps / XR; // Warps per unique scale
|
||||
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<MWarps, XR, get_warp_size()>, // 1, 1, 64
|
||||
tuple<sequence<YPerTile>, sequence<X0, X1>>, // 1, (2, 4)
|
||||
tuple<sequence<0, 2, 0>, sequence<0>>, //(1, 4, 1) (64)
|
||||
tuple<sequence<0, 1, 1>, sequence<2>>,
|
||||
sequence<2, 1>, //(2, 1(in Y dimension))
|
||||
sequence<0, 0>>{});
|
||||
tile_distribution_encoding<sequence<MWarps, XR, get_warp_size()>,
|
||||
tuple<sequence<YPerTile>, sequence<X0, X1>>,
|
||||
tuple<sequence<0, 2, 0>, sequence<0>>,
|
||||
tuple<sequence<0, 1, 1>, sequence<2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
else // XPerQ > WarpGemm::kN * NWarps
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user