Support b_scale: (#2350)

- extend pipeline v1 and v3
 - add instances
 - add tests
 - add example

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Enrico Degregori
2025-07-25 03:49:58 +02:00
committed by GitHub
parent 2addf05b91
commit b01a27ff22
24 changed files with 3744 additions and 1660 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -17,6 +17,22 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
#ifdef CK_USE_WMMA
void add_device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
Col,
Row,
F16,
I4,
F16,
F16,
1,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_USE_XDL
void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
Col,
@@ -31,6 +47,7 @@ void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
PassThrough,
PassThrough>>>& instances);
#endif
#endif
template <typename ADataType,
typename BDataType,
@@ -77,7 +94,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmV2
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
#ifdef CK_USE_WMMA
add_device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
#endif
#ifdef CK_USE_XDL
add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
#endif
}
}