Add support for mixed precision bf16&int8 grouped gemm (#1166)

* add support for mixed precision bf16&int8 grouped gemm

* fix gfx versions and add bf16 kbatch condition

* added reviewers comments
This commit is contained in:
jakpiase
2024-02-21 10:35:35 +01:00
committed by GitHub
parent 66736edb95
commit 32d4be3d09
10 changed files with 1159 additions and 19 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -97,6 +97,35 @@ void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances(
PassThrough,
PassThrough>>>& instances);
// bf16_inputA i8_inputB
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Empty_Tuple,
Row,
BF16,
I8,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Empty_Tuple,
Row,
BF16,
I8,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <typename ALayout,
typename BLayout,
typename ELayout,
@@ -180,6 +209,24 @@ struct DeviceOperationInstanceFactory<
}
}
// bf16_i8_input
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};