Add other layouts for FP8 block scaled gemm (#2665)

* Start adding other layouts for gemm_ab_scale

* Add some instances

* Create tensor descriptors for A/B scales depending on A/B layout

* Fix formatting

* Revert some comments

* Revert commented instances in CMakeLists.txt

* Add some more instances for col-row gemm

* Enable more row,row instances

* Use occupancy=1 for col,row layout to avoid spills

[ROCm/composable_kernel commit: 26d3300930]
This commit is contained in:
Sami Remes
2025-08-18 11:46:10 +03:00
committed by GitHub
parent a4d70b6e13
commit 13bfcba04c
15 changed files with 758 additions and 13 deletions

View File

@@ -17,6 +17,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
// Row, Col
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
@@ -88,6 +89,152 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Row, Row
// Registration hooks for FP8(A) x FP8(B) -> BF16 block-scaled GEMM instances with
// A row-major (mk), B row-major (kn), C row-major (mn). The template arguments
// 1, 128, 128 match the "_1_128_128_" tag in the instance names — presumably the
// per-block scale granularity along M/N/K; confirm against the instance definitions.
// NOTE(review): "comp" vs "mem_v1" look like compute-bound vs memory-bound tunings,
// and "kpadding" variants appear to pad the K dimension for irregular sizes —
// naming inferred from the instance-name convention, TODO confirm.
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Same layout/type configuration, K-padded variant of the compute-tuned instances.
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Memory-bound ("mem_v1") tuned instances, unpadded.
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Memory-bound ("mem_v1") tuned instances, K-padded variant.
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Col, Row
// Registration hooks for FP8(A) x FP8(B) -> BF16 block-scaled GEMM instances with
// A column-major (km), B row-major (kn), C row-major (mn). Template arguments
// mirror the mk_kn_mn group above except ALayout = Col.
// NOTE(review): per the commit message, the Col,Row instances use occupancy=1 to
// avoid register spills — that tuning lives in the instance definitions, not here.
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Col,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Compute-tuned instances, K-padded variant.
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Col,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Memory-bound ("mem_v1") tuned instances, unpadded.
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Col,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Memory-bound ("mem_v1") tuned instances, K-padded variant.
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Col,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <typename A0DataType,
@@ -154,6 +301,32 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
op_ptrs);
}
// Row-major A x row-major B (mk,kn -> mn): register all four tuned instance
// sets declared above. Layout dispatch is resolved at compile time via
// if constexpr, so only the matching branch is instantiated.
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances(
op_ptrs);
}
// Column-major A x row-major B (km,kn -> mn): same four tuning variants for
// the Col,Row layout.
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances(
op_ptrs);
}
}
#endif
return op_ptrs;