Revert "[BlockScale GEMM] FP8 Blockscale GEMM optimization and ckProfiler (#1913)" (#1933)

This reverts commit 020148d0f7.
This commit is contained in:
asleepzzz
2025-03-03 23:17:39 +08:00
committed by GitHub
parent 1bf29478cd
commit ef16010273
18 changed files with 663 additions and 1018 deletions

View File

@@ -17,7 +17,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
@@ -28,14 +28,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_ins
F32,
Tuple<>,
BF16,
1,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
@@ -46,14 +46,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_in
F32,
Tuple<>,
BF16,
1,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
@@ -64,14 +64,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_i
F32,
Tuple<>,
BF16,
1,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
@@ -82,7 +82,61 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_
F32,
Tuple<>,
BF16,
1,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
@@ -109,7 +163,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
B1DataType,
Tuple<>,
CDataType,
1,
128,
128,
128,
ck::tensor_operation::element_wise::PassThrough,
@@ -126,7 +180,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
B1DataType,
Tuple<>,
CDataType,
1,
128,
128,
128,
ck::tensor_operation::element_wise::PassThrough,
@@ -144,14 +198,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
op_ptrs);
}
}