implement device batched gemm b scale for wmma (#2825)

* rebased on top of develop

* fixed missing shuffeling and wrong indexing

* added tests for batched_b_scale

* added missing files

* fixed wrong stride computation and removed k batching (for now) due to precision issues

* reinstated k-batching with PRNG constrained to -1..1

* added specialization of GeneratorTensor_3 for int4 and fixed internal overflow

* added k-batching to reference and increased tolerances for test

* changed gemm_b_scale and gemm_universal tests to use correct parameters

* adressed review commentsd

* ported fixes back to non-batched version of b_scale

* adressed review comments

* run clang-format on older commits

* add type-conversion to AccDataType and then to CDataType to exactly mimic GPU's behavior

* added newline at end of file

* reflected changes from muitl-abd branch in batched b_scale

* fixed gfx11 issue

* changed range for pki4 to -1...1 (-0.5...0.5 never really made sense for i4 anyway and always should have caused compiler errors, but since there was no int4 specialization of GeneratorTensor3 until now, this passed

* run clang format

* set range of i4 generation to 0...1 for upstream tests to pass. This replicated previous behavior, which however means that it is NOT properly tested.

* reduced range for pk_i4 even further to 0..0

* removed failing xld instances. Failure now uncovered now that tests were fixed

* removed generation of int4 values entierly

* divide B buffer by BPackedSize

---------

Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
This commit is contained in:
kabrahamAMD
2025-10-16 20:00:42 +02:00
committed by GitHub
parent d7278cc664
commit c4b2da9cbd
22 changed files with 1352 additions and 97 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -8,6 +8,7 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include <stdexcept>
namespace ck {
namespace tensor_operation {
@@ -30,14 +31,18 @@ struct ReferenceBatchedGemm : public device::BaseOperator
Tensor<CDataType>& c_g_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
const int k_batch = 1)
: a_g_m_k_{a_g_m_k},
b_g_k_n_{b_g_k_n},
c_g_m_n_{c_g_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
c_element_op_{c_element_op},
k_batch_(k_batch)
{
if(k_batch < 1)
throw std::invalid_argument("Batch size must be at least 1");
}
const Tensor<ADataType>& a_g_m_k_;
@@ -47,6 +52,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
const int k_batch_;
};
// Invoker
@@ -59,23 +66,54 @@ struct ReferenceBatchedGemm : public device::BaseOperator
auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) {
const int K = arg.a_g_m_k_.mDesc.GetLengths()[2];
AccDataType v_acc = 0;
// simulate fp accuacy implications of k batching
std::vector<CDataType> partialSums(arg.k_batch_);
for(int k = 0; k < K; ++k)
for(int batchIdx = 0; batchIdx < arg.k_batch_; ++batchIdx)
{
ADataType v_a;
BDataType v_b;
int batchSize = std::max(K / arg.k_batch_, 1);
int batchStart = batchSize * batchIdx;
int batchEnd = batchSize * (batchIdx + 1);
// add any extra round-off to last batch
if(batchIdx == arg.k_batch_ - 1)
batchEnd = K;
arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
AccDataType v_acc = 0;
for(int k = batchStart; k < batchEnd; ++k)
{
ADataType v_a;
BDataType v_b;
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
partialSums[batchIdx] = ck::type_convert<CDataType>(v_c);
}
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
// finally, sum up partial sums
// note that we can't simulate the random nature of atomic additions, but at least
// we can simulate the effect of partial sums
AccDataType v_c = 0;
if(arg.k_batch_ > 1)
{
for(int batchIdx = 0; batchIdx < arg.k_batch_; batchIdx++)
{
// mimic the way fp operations would be done on GPU for k-batching
v_c = ck::type_convert<AccDataType>(ck::type_convert<CDataType>(
ck::type_convert<AccDataType>(v_c) +
ck::type_convert<AccDataType>(partialSums[batchIdx])));
}
}
else
{
v_c = ck::type_convert<AccDataType>(partialSums[0]);
}
arg.c_g_m_n_(g, m, n) = ck::type_convert<CDataType>(v_c);
};
@@ -108,9 +146,11 @@ struct ReferenceBatchedGemm : public device::BaseOperator
Tensor<CDataType>& c_g_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
const int k_batch = 1)
{
return Argument{a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op};
return Argument{
a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op, k_batch};
}
static auto MakeInvoker() { return Invoker{}; }

View File

@@ -5,6 +5,8 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <memory>
@@ -16,6 +18,8 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_USE_XDL)
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
void add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmV2BScale<Row,
@@ -31,6 +35,25 @@ void add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_inst
PassThrough,
PassThrough>>>& instances);
#endif
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8)) // TODO: really, or?
void add_device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmV2BScale<Row,
Col,
Row,
F16,
I4,
F16,
F16,
1,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif // CK_ENABLE_FP16 || CK_ENABLE_FP8
#endif // CK_USE_WMMA
template <typename ADataType,
typename BDataType,
@@ -40,6 +63,7 @@ template <typename ADataType,
typename BLayout,
typename CLayout,
index_t ScaleBlockK>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatchedGemmV2BScale<
ALayout,
BLayout,
@@ -77,8 +101,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
#if defined(CK_USE_XDL)
add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
op_ptrs);
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
add_device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instances(
op_ptrs);
#endif // CK_USE_WMMA
}
}

View File

@@ -1,10 +1,13 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
set(BATCHED_GEMM_B_SCALE_INSTANCES)
list(APPEND BATCHED_GEMM_B_SCALE_INSTANCES
device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_batched_gemm_b_scale_wmma_f16_i4_f16/device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instance.cpp
)
set_source_files_properties(device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_batched_gemm_b_scale_wmma_f16_i4_f16/device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_batched_gemm_b_scale_instance ${BATCHED_GEMM_B_SCALE_INSTANCES})

View File

@@ -0,0 +1,72 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I4 = pk_i4_t;
using F16 = half_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| CLayout|AData|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Compute| Compute| PermuteA| PermuteB|
//################################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| TypeA| TypeB| | |
//################################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| Scheduler| Verision| | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //1
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //2
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //3
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //4
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //5
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //7
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //8
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //9
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //10
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //11
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //12
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //13
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //14
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //15
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //16
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //17
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //18
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //19
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false> //20
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmV2BScale<Row,
Col,
Row,
F16,
I4,
F16,
F16,
1,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_instances<Intrawave,
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -45,9 +45,6 @@ using device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances = std::t
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //1
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //3
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //4
//Latency friendly
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //5
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6

View File

@@ -51,9 +51,6 @@ using device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //1
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //3
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //4
//Latency friendly
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //5
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6