mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 04:37:02 +00:00
debugging permuteN
This commit is contained in:
@@ -11,9 +11,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
gemm_quant.cpp
|
||||
gemm_aquant_quantgrouped.cpp
|
||||
gemm_aquant_quantgrouped_preshufflequant.cpp
|
||||
gemm_bquant_quantgrouped_bf8i4.cpp
|
||||
gemm_bquant_quantgrouped_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_bf8.cpp
|
||||
# gemm_bquant_quantgrouped_bf8i4.cpp
|
||||
# gemm_bquant_quantgrouped_fp8i4.cpp
|
||||
# gemm_bquant_quantgrouped_bf8.cpp
|
||||
gemm_bquant_quantgrouped_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant.cpp
|
||||
|
||||
@@ -9,51 +9,94 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
|
||||
// lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant",
|
||||
// "1x1x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
// ck_tile::bf8_t,
|
||||
// ck_tile::half_t,
|
||||
// float>{});
|
||||
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
// lut[hash_multiple_strings(
|
||||
// {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
// ck_tile::pk_int4_t,
|
||||
// ck_tile::half_t,
|
||||
// ck_tile::fp8_t>{});
|
||||
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
// lut[hash_multiple_strings(
|
||||
// {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
// ck_tile::pk_int4_t,
|
||||
// ck_tile::half_t,
|
||||
// ck_tile::bf8_t>{});
|
||||
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
}
|
||||
|
||||
@@ -21,37 +21,39 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
// lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
// ck_tile::bf8_t,
|
||||
// ck_tile::half_t,
|
||||
// float>{});
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
// lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})]
|
||||
// =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
// ck_tile::pk_int4_t,
|
||||
// ck_tile::half_t,
|
||||
// ck_tile::fp8_t>{});
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
// lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})]
|
||||
// =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
// ck_tile::pk_int4_t,
|
||||
// ck_tile::half_t,
|
||||
// ck_tile::bf8_t>{});
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
}
|
||||
|
||||
@@ -21,39 +21,40 @@ void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
// lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant",
|
||||
// "1x1x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
// ck_tile::bf8_t,
|
||||
// ck_tile::half_t,
|
||||
// float>{});
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
// lut[hash_multiple_strings(
|
||||
// {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
// ck_tile::pk_int4_t,
|
||||
// ck_tile::half_t,
|
||||
// ck_tile::fp8_t>{});
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
// lut[hash_multiple_strings(
|
||||
// {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
// ck_tile::pk_int4_t,
|
||||
// ck_tile::half_t,
|
||||
// ck_tile::bf8_t>{});
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
}
|
||||
|
||||
@@ -17,9 +17,9 @@ auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("h", "false", "Print help message")
|
||||
.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("m", "128", "m dimension")
|
||||
.insert("n", "128", "n dimension")
|
||||
.insert("k", "128", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row or Column")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row or Column")
|
||||
.insert("bq_layout", "C", "Bq tensor data layout - Row or Column")
|
||||
@@ -33,14 +33,14 @@ auto create_args(int argc, char* argv[])
|
||||
"fp8",
|
||||
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
|
||||
"or bf8i4")
|
||||
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
|
||||
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
|
||||
.insert("warmup", "1", "Number of iterations before benchmarking the kernel")
|
||||
.insert("repeat", "0", "Number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "SplitK value")
|
||||
.insert("device", "0", "Device id that will be used to run the kernel")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("flush_cache", "true", "Flush cache before running the kernel")
|
||||
.insert("rotating_count", "1000", "Rotating count")
|
||||
.insert("rotating_count", "0", "Rotating count")
|
||||
.insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol")
|
||||
.insert("preshuffleb", "false", "Enable preshuffle of tensor B")
|
||||
.insert("preshufflequant", "false", "Enable preshuffle of quant tensor")
|
||||
@@ -91,12 +91,12 @@ void aquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void bquant_quantgrouped_bf8_instance_factory(
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
@@ -125,9 +125,9 @@ int main(int argc, char* argv[])
|
||||
aquant_quantgrouped_instance_factory(lut);
|
||||
aquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
bquant_quantgrouped_fp8_instance_factory(lut);
|
||||
bquant_quantgrouped_bf8_instance_factory(lut);
|
||||
bquant_quantgrouped_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_bf8i4_instance_factory(lut);
|
||||
// bquant_quantgrouped_bf8_instance_factory(lut);
|
||||
// bquant_quantgrouped_fp8i4_instance_factory(lut);
|
||||
// bquant_quantgrouped_bf8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut);
|
||||
|
||||
@@ -210,7 +210,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
static constexpr bool TiledMMAPermuteN = false; // N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
|
||||
@@ -481,7 +481,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f,
|
||||
3.0f /*, fill_seed(gen)*/}(b_k_n);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
@@ -543,7 +544,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
b_k_n.SetZero();
|
||||
bq_tensor_ptr->SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
@@ -600,6 +600,14 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
{
|
||||
printf("PreshuffleB with TiledMMAPermuteN\n");
|
||||
b_k_n_dev = ck_tile::shuffle_b_permuteN<GemmConfig>(b_k_n);
|
||||
printf("b_k_n_dev.get_lengths(): %lu, %lu, %lu, %lu, %lu, %lu, %lu\n",
|
||||
b_k_n_dev.get_lengths()[0],
|
||||
b_k_n_dev.get_lengths()[1],
|
||||
b_k_n_dev.get_lengths()[2],
|
||||
b_k_n_dev.get_lengths()[3],
|
||||
b_k_n_dev.get_lengths()[4],
|
||||
b_k_n_dev.get_lengths()[5],
|
||||
b_k_n_dev.get_lengths()[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -624,8 +632,44 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN)
|
||||
{
|
||||
printf("Preshuffle BQ with TiledMMAPermuteN \n");
|
||||
for(int i = 0; i < static_cast<int>((*bq_tensor_ptr).get_lengths()[0]); i++)
|
||||
{
|
||||
for(int j = 0; j < static_cast<int>((*bq_tensor_ptr).get_lengths()[1]); j++)
|
||||
{
|
||||
printf("(*bq_tensor_ptr)[%d][%d]: %f\n", i, j, (*bq_tensor_ptr)(i, j));
|
||||
}
|
||||
}
|
||||
ck_tile::HostTensor<BQDataType> bq_permuted_host =
|
||||
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr);
|
||||
printf("bq_permuted_host.get_lengths(): %lu, %lu, %lu, %lu, %lu\n",
|
||||
bq_permuted_host.get_lengths()[0],
|
||||
bq_permuted_host.get_lengths()[1],
|
||||
bq_permuted_host.get_lengths()[2],
|
||||
bq_permuted_host.get_lengths()[3],
|
||||
bq_permuted_host.get_lengths()[4]);
|
||||
for(int i = 0; i < static_cast<int>(bq_permuted_host.get_lengths()[0]); i++)
|
||||
{
|
||||
for(int j = 0; j < static_cast<int>(bq_permuted_host.get_lengths()[1]); j++)
|
||||
{
|
||||
for(int k = 0; k < static_cast<int>(bq_permuted_host.get_lengths()[2]); k++)
|
||||
{
|
||||
for(int l = 0; l < static_cast<int>(bq_permuted_host.get_lengths()[3]); l++)
|
||||
{
|
||||
for(int m = 0; m < static_cast<int>(bq_permuted_host.get_lengths()[4]);
|
||||
m++)
|
||||
{
|
||||
printf("bq_permuted_host[%d][%d][%d][%d][%d]: %f\n",
|
||||
i,
|
||||
j,
|
||||
k,
|
||||
l,
|
||||
m,
|
||||
bq_permuted_host(i, j, k, l, m));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
|
||||
@@ -1197,7 +1197,7 @@ struct tile_window_with_static_lengths
|
||||
using ThreadBuf = thread_buffer<DataType, 2>;
|
||||
auto buf = tensor_view.template get_vectorized_elements<ThreadBuf>(coord, 0);
|
||||
auto value = buf.at(number<0>{}); // Extract first element from thread buffer
|
||||
printf(" %s[%d,%d] = %f", label, i, j, type_convert<float>(value));
|
||||
printf(" %s[%d,%d] = %f\n", label, i, j, type_convert<float>(value));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
|
||||
static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
|
||||
// static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
@@ -204,14 +204,40 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale;
|
||||
index_t reg_offset = [&]() {
|
||||
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
|
||||
{
|
||||
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
|
||||
kQScale;
|
||||
}
|
||||
else
|
||||
{
|
||||
return nIter * KPerBlockBQ + kQScale;
|
||||
}
|
||||
}();
|
||||
// constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale;
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float scale_reg_f = cvt_scale_to_fp32(scale_reg);
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("scale_reg_f: %f, reg_offset: %d\n", scale_reg_f, reg_offset);
|
||||
printf("nIter: %d, NWarp: %d, WG::kN: %d, QuantGroupSize::kN: %d, "
|
||||
"KPerBlockBQ: %d, kQScale: %d\n",
|
||||
static_cast<int>(nIter),
|
||||
NWarp,
|
||||
WG::kN,
|
||||
static_cast<int>(QuantGroupSize::kN),
|
||||
static_cast<int>(KPerBlockBQ),
|
||||
static_cast<int>(kQScale));
|
||||
}
|
||||
|
||||
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
|
||||
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
|
||||
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
|
||||
c_ref = c_ref + acc_val * scale_reg_f;
|
||||
// if(get_block_id() == 0 && get_thread_id() == 0) {
|
||||
// printf("acc_val: %f, scale_reg_f: %f\n", acc_val, scale_reg_f);
|
||||
// }
|
||||
c_ref = c_ref + acc_val * scale_reg_f;
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
@@ -654,7 +654,10 @@ struct QuantGemmKernel
|
||||
(splitk_batch_offset.splitted_k /
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("kFlatN: %d, kFlatK: %d\n", kFlatN, kFlatK);
|
||||
}
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
@@ -989,10 +992,19 @@ struct QuantGemmKernel
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("TilePartitioner::KPerBlock %d, QuantGroupSize::kK: %d, "
|
||||
"TilePartitioner::NPerBlock %d, QuantGroupSize::kN: %d\n",
|
||||
TilePartitioner::KPerBlock,
|
||||
QuantGroupSize::kK,
|
||||
TilePartitioner::NPerBlock,
|
||||
QuantGroupSize::kN);
|
||||
}
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
|
||||
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
|
||||
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}, // 1
|
||||
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}), // 16
|
||||
{0, i_n / QuantGroupSize::kN});
|
||||
}
|
||||
}
|
||||
@@ -1152,6 +1164,11 @@ struct QuantGemmKernel
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I3);
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
bq_block_window.template print_tile_window_range<BQDataType>(
|
||||
0, 1, 0, 32, "bq block window");
|
||||
}
|
||||
return GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
|
||||
Reference in New Issue
Block a user