mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 21:39:15 +00:00
Merge commit 'e6b5e31c20bf859a869f7295489a2bbe10ef6eca' into develop
This commit is contained in:
@@ -18,168 +18,141 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if(defined(CK_ENABLE_F16) || defined(CK_ENABLE_FP8))
|
||||
using TGemmMulMulF8F8F16Instances =
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>;
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1(
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2(
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances(
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances(
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances(
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances(
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances(
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2(
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2(
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2(
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances_v2(
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances_v2(
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p1(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p3(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p4(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p5(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p6(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
TGemmMulMulF8F8F16Instances& instances);
|
||||
#endif
|
||||
|
||||
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
|
||||
using TGemmMulMulF8F8BF16Instances =
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>;
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p1(
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p2(
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances(
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances(
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances(
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances(
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances(
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2(
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2(
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2(
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances_v2(
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances_v2(
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p1(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p3(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p4(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p5(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p6(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
TGemmMulMulF8F8BF16Instances& instances);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -239,6 +212,31 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p6(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances_v2(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -262,6 +260,31 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p6(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p1(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances_v2(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user