mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
Redesign the DPP8 GEMM kernel to use warp-wise component (#863)
* Redesign the DPP8 GEMM kernel to use warp-wise component
* Review: Improve error messages
* Review: Remove unnecessary empty lines
* Review: Fix M, N per thread names
* Review: Rename mfma_input_type to dpp_input_type
* Review: Fix tensor adaptor; remove unnecessary element
* Review: Remove calls to dpp_gemm's MakeCDescriptor
* Review: Add blockwise doc, change function names to include dimension names
* Review: Remove duplicated code; Move Block2CtileMap alias to the top of the file
* Review: Add __restrict__ keywords
* Review: Use MatrixPadder for padding A, B, C matrices
* Review: Remove hardcoded datatypes
* Review: Change names from FloatX to XDataType
* Review: Introduce AK0 and BK0 instead of a single K0
* Review: Remove construction of dpp_datatypes object
* Review: Rename DppInstrRunner to DppLanegroupGemm
[ROCm/composable_kernel commit: 37a8c1f756]
This commit is contained in:
committed by
GitHub
parent
3446ff1e7d
commit
88bb9d5fac
@@ -23,7 +23,7 @@ void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances(
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
@@ -38,7 +38,7 @@ void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances(
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
@@ -53,7 +53,7 @@ void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances(
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
@@ -68,7 +68,7 @@ void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances(
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
@@ -374,7 +374,7 @@ struct DeviceOperationInstanceFactory<
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
@@ -385,7 +385,7 @@ struct DeviceOperationInstanceFactory<
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
@@ -397,7 +397,7 @@ struct DeviceOperationInstanceFactory<
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
@@ -408,7 +408,7 @@ struct DeviceOperationInstanceFactory<
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user