mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
Support b_scale: (#2350)
- extend pipeline v1 and v3
- add instances
- add tests
- add example

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -17,6 +17,22 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
|
||||
#ifdef CK_USE_WMMA
|
||||
// Forward declaration only; the definition is not part of this header.
// It is visible only when CK_USE_WMMA is defined, inside the enclosing
// CK_ENABLE_FP16 / CK_ENABLE_FP8 guard.
//
// `instances` is taken by non-const reference: a vector of owning pointers to
// DeviceGemmV2BScale specialised with Row, Col, Row, F16, I4, F16, F16, 1, 128
// and three PassThrough operations.
// NOTE(review): presumably the function appends the WMMA b_scale GEMM device
// instances to `instances` (going by its name) -- confirm against the
// definition.
// NOTE(review): the meaning of the integer arguments 1 and 128 is not shown
// here -- check the DeviceGemmV2BScale template parameter list.
void add_device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
|
||||
Col,
|
||||
Row,
|
||||
F16,
|
||||
I4,
|
||||
F16,
|
||||
F16,
|
||||
1,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_USE_XDL
|
||||
void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
|
||||
Col,
|
||||
@@ -31,6 +47,7 @@ void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -77,7 +94,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmV2
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
#ifdef CK_USE_WMMA
|
||||
add_device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
|
||||
#endif
|
||||
#ifdef CK_USE_XDL
|
||||
add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user