mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Added bf16 instances grouped gemm fixed nk (#1825)
* Feat: Add bf16 input instances * feat: Add BF16 profiler code * fix: reorder enum types * fix: CI fail due to clang-format * fix: clang script format issue * fix: clang format broke cmakelist file
This commit is contained in:
@@ -126,6 +126,35 @@ void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// bf16_inputA bf16_inputB
|
||||
#if defined(CK_ENABLE_BF16)
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif // CK_ENABLE_BF16
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
@@ -227,6 +256,24 @@ struct DeviceOperationInstanceFactory<
|
||||
}
|
||||
#endif
|
||||
|
||||
// bf16_inputA bf16_inputB
|
||||
#if defined(CK_ENABLE_BF16)
|
||||
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
|
||||
is_same_v<EDataType, bhalf_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_ENABLE_BF16
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user