mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
[BlockScale GEMM] FP8 Blockscale GEMM optimization and ckProfiler (#1913)
* Added two kernels for the M=32 problem * Comment the first one * Enable multiply_multiply for Scale_Block_M = 1 for deepseek * Modify the a_thread offset since the A data load is different from B. * edit fp8 ab scale for Scale_Block_M=1 * edit GemmSpec to MNKPadding * enable blockwise pipeline v1 and v2. v1 works for small K. * add instance for gemm_ab_scale * fix CMakeLists of ckProfiler * optimize blockscale gemm. todo: reduce vgpr usage * fix a correctness bug * sanity checked * revert ckprofiler cmake changes * clang format * revert unnecessary changes. * remove commented-out code. --------- Co-authored-by: mtgu0705 <mtgu@amd.com> Co-authored-by: chenjun <junchen2@amd.com>
This commit is contained in:
@@ -17,7 +17,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -28,14 +28,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -46,14 +46,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -64,14 +64,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -82,61 +82,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpaddin
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
@@ -163,7 +109,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
B1DataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -180,7 +126,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
B1DataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -198,20 +144,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user