Merge commit '161835533becff72c71d20eff1e907a702820252' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-03 17:14:58 +00:00
parent 0cb95a3e70
commit 834ee396bb
30 changed files with 2482 additions and 86 deletions

View File

@@ -31,6 +31,7 @@ using TGemmMulMulF8F8F16Instances =
PassThrough,
MultiplyMultiply>>>;
#ifdef CK_USE_XDL
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1(
TGemmMulMulF8F8F16Instances& instances);
@@ -86,6 +87,21 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16
TGemmMulMulF8F8F16Instances& instances);
#endif
#ifdef CK_USE_WMMA
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p1(
TGemmMulMulF8F8F16Instances& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p2(
TGemmMulMulF8F8F16Instances& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p3(
TGemmMulMulF8F8F16Instances& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p4(
TGemmMulMulF8F8F16Instances& instances);
#endif
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
using TGemmMulMulF8F8BF16Instances =
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
@@ -100,6 +116,7 @@ using TGemmMulMulF8F8BF16Instances =
PassThrough,
MultiplyMultiply>>>;
#ifdef CK_USE_XDL
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p1(
TGemmMulMulF8F8BF16Instances& instances);
@@ -153,7 +170,21 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma1
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p6(
TGemmMulMulF8F8BF16Instances& instances);
#endif
#ifdef CK_USE_WMMA
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p1(
TGemmMulMulF8F8BF16Instances& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p2(
TGemmMulMulF8F8BF16Instances& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p3(
TGemmMulMulF8F8BF16Instances& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p4(
TGemmMulMulF8F8BF16Instances& instances);
#endif
#endif
template <typename ADataType,
@@ -200,6 +231,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
#ifdef CK_USE_XDL
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p1(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p2(
@@ -237,6 +269,17 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances_v2(
op_ptrs);
#endif
#ifdef CK_USE_WMMA
add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p1(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p2(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p3(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p4(
op_ptrs);
#endif
}
}
#endif
@@ -248,6 +291,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
#ifdef CK_USE_XDL
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p1(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p2(
@@ -285,6 +329,17 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances_v2(
op_ptrs);
#endif
#ifdef CK_USE_WMMA
add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p1(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p2(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p3(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p4(
op_ptrs);
#endif
}
}
#endif