implement device batched gemm b scale for wmma (#2825)

* rebased on top of develop

* fixed missing shuffling and wrong indexing

* added tests for batched_b_scale

* added missing files

* fixed wrong stride computation and removed k batching (for now) due to precision issues

* reinstated k-batching with PRNG constrained to -1..1

* added specialization of GeneratorTensor_3 for int4 and fixed internal overflow

* added k-batching to reference and increased tolerances for test

* changed gemm_b_scale and gemm_universal tests to use correct parameters

* addressed review comments

* ported fixes back to non-batched version of b_scale

* addressed review comments

* run clang-format on older commits

* add type-conversion to AccDataType and then to CDataType to exactly mimic GPU's behavior

* added newline at end of file

* reflected changes from multi-abd branch in batched b_scale

* fixed gfx11 issue

* changed range for pk_i4 to -1...1 (-0.5...0.5 never really made sense for i4 anyway and should always have caused compiler errors, but since there was no int4 specialization of GeneratorTensor_3 until now, this passed)

* run clang-format

* set range of i4 generation to 0...1 for upstream tests to pass. This replicates the previous behavior, which however means that it is NOT properly tested.

* reduced range for pk_i4 even further to 0..0

* removed failing xdl instances; their failure was only uncovered now that the tests were fixed

* removed generation of int4 values entirely

* divide B buffer by BPackedSize

---------

Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>

[ROCm/composable_kernel commit: c4b2da9cbd]
This commit is contained in:
kabrahamAMD
2025-10-16 20:00:42 +02:00
committed by GitHub
parent 62afd9eb14
commit 06d76b160e
22 changed files with 1352 additions and 97 deletions

View File

@@ -289,7 +289,6 @@ int main(int argc, char* argv[])
case 0: break;
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
@@ -303,7 +302,6 @@ int main(int argc, char* argv[])
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});

View File

@@ -275,7 +275,7 @@ int main(int argc, char* argv[])
break;
case 3:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-1, 1});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
@@ -289,7 +289,7 @@ int main(int argc, char* argv[])
break;
default:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-1, 1});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});

View File

@@ -264,7 +264,7 @@ struct GeneratorTensor_2<ck::pk_i4_t>
{
int hi = std::rand() % (max_value - min_value) + min_value + 8;
int lo = std::rand() % (max_value - min_value) + min_value + 8;
ck::pk_i4_t r = ((hi << 4) + lo) & 0xff;
ck::pk_i4_t r = (((hi & 0xf) << 4) + (lo & 0xf));
return r;
}
};
@@ -436,6 +436,22 @@ struct GeneratorTensor_3<ck::f4x2_pk_t>
}
};
template <>
struct GeneratorTensor_3<ck::pk_i4_t>
{
int min_value = 0;
int max_value = 1;
template <typename... Is>
ck::pk_i4_t operator()(Is...)
{
int hi = std::rand() % (max_value - min_value) + min_value + 8;
int lo = std::rand() % (max_value - min_value) + min_value + 8;
ck::pk_i4_t r = (((hi & 0xf) << 4) + (lo & 0xf));
return r;
}
};
template <>
struct GeneratorTensor_3<ck::f6x32_pk_t>
{

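For context, here is a minimal standalone sketch (not part of the diff; plain C++ for illustration only) of the nibble packing these generators produce: each pk_i4_t byte stores two signed 4-bit values with a +8 bias, hi in the upper nibble and lo in the lower one, and the dequantization path shown further below undoes the bias by subtracting 8 again.
#include <cstdint>
#include <cstdio>
// Illustrative helpers mirroring GeneratorTensor_2/3<ck::pk_i4_t>: a signed
// 4-bit value v in [-8, 7] is stored as the unsigned nibble (v + 8) & 0xf.
static uint8_t pack_i4x2(int hi, int lo)
{
    return static_cast<uint8_t>((((hi + 8) & 0xf) << 4) | ((lo + 8) & 0xf));
}
static void unpack_i4x2(uint8_t packed, int& hi, int& lo)
{
    hi = ((packed >> 4) & 0xf) - 8; // same unbiasing as the profiler's dequant loop
    lo = ((packed >> 0) & 0xf) - 8;
}
int main()
{
    int hi = 0, lo = 0;
    const uint8_t p = pack_i4x2(-1, 1); // 0x79
    unpack_i4x2(p, hi, lo);
    std::printf("packed=0x%02x hi=%d lo=%d\n", p, hi, lo); // hi=-1 lo=1
    return 0;
}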
View File

@@ -0,0 +1,836 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename ComputePtrOffsetOfStridedBatch,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_batched_gemm_b_scale_wmma_cshuffle_v3(
typename GridwiseGemm::Argument karg, // This works for now but it actually receives a
// DeviceBatchedGemm_Wmma_CShuffleV3::Argument
// argument through implicit conversion to base class!
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<c_data_type, ck::half_t> ||
std::is_same_v<c_data_type, ck::bhalf_t>)))
{
#endif
// The normal approach to batching would be to increase the grid size by just stretching out
// the grid Z dimension (which is the outermost dimension), but this depends on lower level
// functions not directly using the Z dimension for other calculations. As it turns out, k
// batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now
// we will use the grid Y dimension for batching. This may be a bit fragile.
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t g_idx = amd_wave_read_first_lane(blockIdx.y);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t c_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
const long_index_t b_scale_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx));
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
// shift A matrices pointer for splitk
typename GridwiseGemm::AsGridPointer p_as_grid_shift;
static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) {
using ADataType_ =
remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::AsDataType_>>;
p_as_grid_shift(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
splitk_batch_offset.a_k_split_offset[i] + a_batch_offset;
});
// shift B matrices pointer for splitk
typename GridwiseGemm::BsGridPointer p_bs_grid_shift;
static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) {
using BDataType_ =
remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::BsDataType_>>;
p_bs_grid_shift(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
});
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
p_as_grid_shift,
p_bs_grid_shift,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op);
#if defined(__gfx11__)
}
#endif
#else
ignore = karg;
ignore = compute_ptr_offset_of_batch;
#endif
}
/// @brief \"Universal\" Batched GEMM operation without SplitK support.
///
/// @par Overview
/// This GEMM operation implements the following mathematical equation:
/// C{G,M,N} = C_op(A_op(A{G,M,K}) * B_op(B{G,K,N}))
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
/// elementwise operations applied to the A, B, and C tensors, respectively.
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
/// scenarios. That's why it's called \"universal\". It's universal through its design
/// and versatilty.
///
/// @note This Kernel implementation currently does not support the SplitK algorithm.
///
/// @tparam ALayout A tensor data layout.
/// @tparam BLayout B tensor data layout.
/// @tparam CLayout C tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam CDataType C tensor data type.
/// @tparam AccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during "CShuffle" data layout optimization.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
/// (after GEMM).
/// @tparam GemmSpec Determines used "padding" version.
/// @tparam BlockSize The number of threads within workgroup.
/// @tparam MPerBlock The input/output data tile size in the M dimension.
/// @tparam NPerBlock The input/output data tile size in the N dimension.
/// @tparam KPerBlock The input data tile size in the K dimension.
/// @tparam AK1 The vector load size from global memory for A tensor.
/// @tparam BK1 The vector load size from global memory for B tensor.
/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction.
/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction.
/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront.
/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront.
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question, "How many threads can be
/// arranged on each input data axis?"
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question: "How many threads to
/// arrange on each input data axis?"
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam BBlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in M dimension.
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in N dimension.
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// thread distribution used for storing data into output
/// tensor across output data layout dimensions.
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
/// Used when storing data to output tensor.
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
/// intrawave).
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
/// instructions.
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
/// instructions.
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
/// in global memory. Currently not supported!
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
/// in global memory (pre-shuffled). Currently not supported!
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename BScaleDataType,
typename CDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockN, // scale block for N
index_t ScaleBlockK, // scale block for K
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale
: public DeviceBatchedGemmV2BScale<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
BScaleDataType,
CDataType,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
// We are inheriting from DeviceBatchedGemm and this base class does not support permuteA and
// permuteB arguments so for now we are not including this functionality.
static_assert(PermuteA == false,
"Permute A functionality not supported by DeviceBatchedGemm operations.\n");
static_assert(PermuteB == false,
"Permute B functionality not supported by DeviceBatchedGemm operations.\n");
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t BatchStrideScaleB)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideC_(BatchStrideC),
BatchStrideScaleB_(BatchStrideScaleB)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_) / GridwiseGemm::BPackedSize;
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
__host__ __device__ constexpr long_index_t GetScaleBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideScaleB_);
}
private:
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideC_;
index_t BatchStrideScaleB_;
};
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale<
ALayout,
BLayout,
Tuple<>, // DsLayout
CLayout,
Tuple<ADataType>,
Tuple<BDataType>,
BScaleDataType,
AccDataType,
CShuffleDataType,
Tuple<>, // DsDataType
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
ScaleBlockN,
ScaleBlockK,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
PermuteA, // PermuteA not supported by DeviceBatchedGemm base class.
PermuteB>; // PermuteB not supported by DeviceBatchedGemm base class.
// Argument
struct Argument : public GridwiseGemm::Argument
{
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t StrideScaleB_,
index_t BatchStrideA_,
index_t BatchStrideB_,
index_t BatchStrideC_,
index_t BatchStrideScaleB_,
const BScaleDataType* p_b_scale_grid_,
index_t Batch_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_,
bool is_reduce_ = false)
: GridwiseGemm::Argument(std::array<const void*, 1>{p_a_grid_},
std::array<const void*, 1>{p_b_grid_},
std::array<const void*, 0>{}, // p_ds_grid_
p_c_grid_,
M_,
N_,
K_,
std::array<index_t, 1>{StrideA_},
std::array<index_t, 1>{StrideB_},
std::array<index_t, 0>{}, // StrideDs_
StrideC_,
StrideScaleB_,
p_b_scale_grid_,
k_batch_,
a_element_op_,
b_element_op_,
c_element_op_,
is_reduce_),
Batch(Batch_),
compute_ptr_offset_of_batch{
BatchStrideA_, BatchStrideB_, BatchStrideC_, BatchStrideScaleB_}
{
}
index_t Batch;
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;
};
/// @brief Helper structure responsible for kernel invocation.
///
/// @paragraph The `Invoker` class is responsible for preparation and invocation of the actual
/// GPU kernel function. It usually determines the launched grid size, prepares kernel
/// arguments, and performs specific kernel configuration selection based on
/// runtime arguments.
///
/// @note If appropriately configured it may measure kernel execution time.
///
struct Invoker : public BaseInvoker
{
/// @brief This function issues GPU kernel execution.
/// @param arg The GPU kernel arguments.
/// @param stream_config The HIP stream configuration helper structure.
/// @return The kernel's average execution time (if time measurement is
/// enabled).
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
// The normal approach to batching would be to increase the grid size by just stretching
// out the grid Z dimension (which is the outermost dimension), but this depends on
// lower level functions not directly using the Z dimension for other calculations. As
// it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset.
// Therefore, for now we will use the grid Y dimension for batching. This may be a bit
// fragile.
gdy *= arg.Batch;
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);
// APackedSize is 1 for the data types used here, but BPackedSize is 2 for pk_i4,
// so the buffer sizes are divided by the packed sizes.
// Note: the grid descriptors and size_a / size_b do *not* take batching into
// account, so we have to manually multiply overall buffer sizes for rotating
// memory by batch.
std::array<std::size_t, 1> size_as_buffers;
size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() *
sizeof(ADataType) / GridwiseGemm::APackedSize * arg_.Batch;
std::array<std::size_t, 1> size_bs_buffers;
size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() *
sizeof(BDataType) / GridwiseGemm::BPackedSize * arg_.Batch;
ck::utility::RotatingMemWrapperMultiABD<Argument,
Tuple<ADataType>,
Tuple<BDataType>,
Tuple<>>
rotating_mem(arg_,
stream_config.rotating_count,
size_as_buffers,
size_bs_buffers,
std::array<std::size_t, 0>{});
rotating_mem.Print();
auto run_flush_cache = [&]() {
ck::utility::flush_icache();
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
// Note: we multiply by batch since we want to clear the C matrix for
// the whole batch. Untested since we don't have k batching ATM.
// Note: This seems incorrect for non-contiguous memory layouts for C
// (padding, gaps).
HIP_CHECK_ERROR(
hipMemsetAsync(arg_.p_e_grid,
0,
arg_.Batch * arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_,
arg_.compute_ptr_offset_of_batch);
}
else
{
auto clear_workspace = [&]() {
// clear c mem
if(arg.KBatch > 1)
// Note: we multiply by batch since we want to clear the C matrix for
// the whole batch. Untested since we don't have k batching ATM.
// Note: This seems incorrect for non-contiguous memory layouts for C
// (padding, gaps).
HIP_CHECK_ERROR(
hipMemsetAsync(arg.p_e_grid,
0,
arg.Batch * arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg,
arg.compute_ptr_offset_of_batch);
}
};
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3<
GridwiseGemm,
ComputePtrOffsetOfStridedBatch,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3<
GridwiseGemm,
remove_reference_t<ComputePtrOffsetOfStridedBatch>,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
else
{
throw std::runtime_error("Pipeline not implemented");
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3<
GridwiseGemm,
ComputePtrOffsetOfStridedBatch,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3<
GridwiseGemm,
remove_reference_t<ComputePtrOffsetOfStridedBatch>,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
{
return false;
}
if constexpr(std::is_same_v<CDataType, ck::half_t> ||
std::is_same_v<CDataType, ck::bhalf_t>)
{
if(arg.KBatch > 1 && ck::is_gfx11_supported())
{
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
return false;
}
}
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
{
if(ck::is_gfx11_supported())
{
return false;
}
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
index_t GetKPerBlock() override { return KPerBlock; }
bool GetPermuteB() override { return PermuteB; }
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideScaleB,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t BatchStrideScaleB,
const BScaleDataType* p_b_scale,
index_t Batch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t KBatch = 1)
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideScaleB,
BatchStrideA,
BatchStrideB,
BatchStrideC,
BatchStrideScaleB,
p_b_scale,
Batch,
KBatch,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideScaleB,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t BatchStrideScaleB,
const void* p_b_scale,
index_t Batch,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideScaleB,
BatchStrideA,
BatchStrideB,
BatchStrideC,
BatchStrideScaleB,
static_cast<const BScaleDataType*>(p_b_scale),
Batch,
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceBatchedGemm_Wmma_CShuffleV3_BScale"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
<< "WaveTile: "
<< MPerWmma << "x"<<NPerWmma << ", "
<< "WaveMap: "
<< MRepeat << "x" << NRepeat << ", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< "KPack: "
<< GridwiseGemm::KPack;
// clang-format on
return str.str();
}
REGISTER_EXTRA_PRINTING_METHODS
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
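For orientation, a rough host-side usage sketch (not part of this commit; DeviceOp, the buffer pointers, and the problem sizes are placeholders) following the MakeArgument / MakeInvoker pattern of the struct above:
// Sketch only: DeviceOp stands for a fully specialized
// DeviceBatchedGemm_Wmma_CShuffleV3_BScale<...> such as those in the instance file below;
// p_a, p_b, p_c, p_b_scale are assumed to be already-allocated device pointers.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
auto device_op = DeviceOp{};
auto argument  = device_op.MakeArgument(p_a, p_b, p_c,
                                        M, N, K,
                                        StrideA, StrideB, StrideC, StrideScaleB,
                                        BatchStrideA, BatchStrideB, BatchStrideC, BatchStrideScaleB,
                                        p_b_scale, Batch,
                                        PassThrough{}, PassThrough{}, PassThrough{},
                                        /*KBatch=*/1);
auto invoker = device_op.MakeInvoker();
if(DeviceOp::IsSupportedArgument(argument))
{
    const float avg_ms = invoker.Run(argument, StreamConfig{nullptr, /*time_kernel=*/true});
    std::cout << "avg time: " << avg_ms << " ms" << std::endl;
}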

View File

@@ -222,6 +222,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
using typename Base::AsGridPointer;
using typename Base::BsGridPointer;
using typename Base::DsGridPointer;
using AsDataType_ = AsDataType;
using BsDataType_ = BsDataType;
struct Problem
{

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -8,6 +8,7 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include <stdexcept>
namespace ck {
namespace tensor_operation {
@@ -30,14 +31,18 @@ struct ReferenceBatchedGemm : public device::BaseOperator
Tensor<CDataType>& c_g_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
const int k_batch = 1)
: a_g_m_k_{a_g_m_k},
b_g_k_n_{b_g_k_n},
c_g_m_n_{c_g_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
c_element_op_{c_element_op},
k_batch_(k_batch)
{
if(k_batch < 1)
throw std::invalid_argument("Batch size must be at least 1");
}
const Tensor<ADataType>& a_g_m_k_;
@@ -47,6 +52,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
const int k_batch_;
};
// Invoker
@@ -59,23 +66,54 @@ struct ReferenceBatchedGemm : public device::BaseOperator
auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) {
const int K = arg.a_g_m_k_.mDesc.GetLengths()[2];
AccDataType v_acc = 0;
// simulate fp accuracy implications of k batching
std::vector<CDataType> partialSums(arg.k_batch_);
for(int k = 0; k < K; ++k)
for(int batchIdx = 0; batchIdx < arg.k_batch_; ++batchIdx)
{
ADataType v_a;
BDataType v_b;
int batchSize = std::max(K / arg.k_batch_, 1);
int batchStart = batchSize * batchIdx;
int batchEnd = batchSize * (batchIdx + 1);
// assign any remainder from dividing K by k_batch to the last batch
if(batchIdx == arg.k_batch_ - 1)
batchEnd = K;
arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
AccDataType v_acc = 0;
for(int k = batchStart; k < batchEnd; ++k)
{
ADataType v_a;
BDataType v_b;
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
partialSums[batchIdx] = ck::type_convert<CDataType>(v_c);
}
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
// finally, sum up partial sums
// note that we can't simulate the random nature of atomic additions, but at least
// we can simulate the effect of partial sums
AccDataType v_c = 0;
if(arg.k_batch_ > 1)
{
for(int batchIdx = 0; batchIdx < arg.k_batch_; batchIdx++)
{
// mimic the way fp operations would be done on GPU for k-batching
v_c = ck::type_convert<AccDataType>(ck::type_convert<CDataType>(
ck::type_convert<AccDataType>(v_c) +
ck::type_convert<AccDataType>(partialSums[batchIdx])));
}
}
else
{
v_c = ck::type_convert<AccDataType>(partialSums[0]);
}
arg.c_g_m_n_(g, m, n) = ck::type_convert<CDataType>(v_c);
};
@@ -108,9 +146,11 @@ struct ReferenceBatchedGemm : public device::BaseOperator
Tensor<CDataType>& c_g_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
const int k_batch = 1)
{
return Argument{a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op};
return Argument{
a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op, k_batch};
}
static auto MakeInvoker() { return Invoker{}; }
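A small worked example of the accuracy effect this models (a sketch, assuming CDataType = ck::half_t and using ck::type_convert as above): each partial sum passes through fp16 before being accumulated, just as atomic adds into an fp16 C buffer would round every contribution.
// Three per-k-batch partial results, already converted to CDataType precision.
float partial[3] = {2048.f, 1.f, 1.f};
// Single-pass accumulation in AccDataType (float), one final conversion: 2050.
float v_single = partial[0] + partial[1] + partial[2];
// k-batched accumulation with an fp16 round-trip after every add, as in the
// reference above: half(2048 + 1) rounds back to 2048, so the result stays 2048.
float v_kbatch = 0.f;
for(int i = 0; i < 3; ++i)
    v_kbatch = ck::type_convert<float>(ck::type_convert<ck::half_t>(v_kbatch + partial[i]));
// v_single == 2050, v_kbatch == 2048 -> hence the relaxed rtol/atol in the tests.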

View File

@@ -5,6 +5,8 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <memory>
@@ -16,6 +18,8 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_USE_XDL)
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
void add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmV2BScale<Row,
@@ -31,6 +35,25 @@ void add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_inst
PassThrough,
PassThrough>>>& instances);
#endif
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8)) // TODO: really, or?
void add_device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmV2BScale<Row,
Col,
Row,
F16,
I4,
F16,
F16,
1,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif // CK_ENABLE_FP16 || CK_ENABLE_FP8
#endif // CK_USE_WMMA
template <typename ADataType,
typename BDataType,
@@ -40,6 +63,7 @@ template <typename ADataType,
typename BLayout,
typename CLayout,
index_t ScaleBlockK>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatchedGemmV2BScale<
ALayout,
BLayout,
@@ -77,8 +101,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
#if defined(CK_USE_XDL)
add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
op_ptrs);
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
add_device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instances(
op_ptrs);
#endif // CK_USE_WMMA
}
}
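A rough fragment (using the aliases from the instance files) of how these registrations are consumed through the factory, matching the GetInstances call in the profiler change further below:
// Enumerate all registered f16 x pk_i4 -> f16, Row/Col/Row, ScaleBlockN = 1,
// ScaleBlockK = 128 batched B-scale GEMM instances; depending on the build flags
// this now contains XDL and/or WMMA kernels.
using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemmV2BScale<
    Row, Col, Row, F16, I4, F16, F16, 1, 128, PassThrough, PassThrough, PassThrough>;
const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;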

View File

@@ -1,10 +1,13 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
set(BATCHED_GEMM_B_SCALE_INSTANCES)
list(APPEND BATCHED_GEMM_B_SCALE_INSTANCES
device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_batched_gemm_b_scale_wmma_f16_i4_f16/device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instance.cpp
)
set_source_files_properties(device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_batched_gemm_b_scale_wmma_f16_i4_f16/device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_batched_gemm_b_scale_instance ${BATCHED_GEMM_B_SCALE_INSTANCES})

View File

@@ -0,0 +1,72 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I4 = pk_i4_t;
using F16 = half_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| CLayout|AData|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Compute| Compute| PermuteA| PermuteB|
//################################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| TypeA| TypeB| | |
//################################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| Scheduler| Verision| | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //1
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //2
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //3
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //4
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //5
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //7
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //8
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //9
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //10
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //11
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //12
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //13
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //14
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //15
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //16
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //17
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //18
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //19
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false> //20
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmV2BScale<Row,
Col,
Row,
F16,
I4,
F16,
F16,
1,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_instances<Intrawave,
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -45,9 +45,6 @@ using device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances = std::t
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //1
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //3
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //4
//Latency friendly
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //5
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6

View File

@@ -51,9 +51,6 @@ using device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //1
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //3
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //4
//Latency friendly
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //5
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -9,12 +9,13 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_b_scale.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
@@ -113,22 +114,21 @@ bool profile_batched_gemm_b_scale_impl(int do_verification,
std::cout << "c_g_m_n: " << c_g_m_n_device_result.mDesc << std::endl;
std::cout << "rotating count: " << rotating_count << std::endl;
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
return 2;
else
return 1;
}();
switch(init_method)
{
case 0: break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
break;
case 2:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
break;
// NOTE: for an int4, there is no point differentiating between decimal and integer
// initialization; also, the random numbers seem to be for an int4_2 type, so we use the range 0...255
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
}
@@ -141,7 +141,8 @@ bool profile_batched_gemm_b_scale_impl(int do_verification,
const auto c_element_op = CElementOp{};
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize() /
BPackedSize);
DeviceMem b1_device_buf(sizeof(BScaleDataType) * b1_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize());
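pk_i4_t stores two 4-bit values per byte, which is why the B device buffer above is divided by BPackedSize. A minimal worked sketch of the sizing, assuming sizeof(BDataType) == 1 for pk_i4_t and that GetElementSpaceSize() reports one entry per logical int4 value (shapes are hypothetical, not taken from this commit):
#include <cstddef>
// hypothetical shapes, for illustration only
constexpr int K = 1024, N = 256;
constexpr int BPackedSize = 2;                                   // two int4 values per pk_i4_t byte
constexpr std::size_t logical_elems = std::size_t(K) * N;        // 262144 logical int4 values
constexpr std::size_t b_bytes = /* sizeof(BDataType) == 1 */ 1 * logical_elems / BPackedSize; // 131072 bytes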
@@ -166,54 +167,63 @@ bool profile_batched_gemm_b_scale_impl(int do_verification,
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM
if(do_verification)
{
Tensor<float> b_g_k_n_dequant({K, N});
Tensor<BScaleDataType> b_g_k_n_dequant({BatchSize, K, N});
float v_b = 0;
for(int bs = 0; bs < BatchSize; bs++)
{
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
ck::pk_i4_t i4x2 = b_g_k_n(bs, k, n).data;
int8_t i4 = 0;
if(k % 2 == 1)
// for proper testing, we need to replicate k_shuffle when used
// see unary_element_wise_operation.hpp
#if CK_USE_PK4_LAYOUT_SHUFFLE
int k_shuffle = (k / 8) * 8 + (k % 2) * 4 + (k % 8) / 2;
#else
int k_shuffle = k;
#endif
ck::pk_i4_t i4x2 = b_g_k_n(bs, k_shuffle, n).data;
int i4 = 0;
if(k_shuffle % 2 == 0)
i4 = (i4x2.data >> 0) & 0xf;
else
i4 = (i4x2.data >> 4) & 0xf;
i4 = i4 - 8;
i4 = i4 - 8;
v_b = ck::type_convert<float>(i4);
b_g_k_n_dequant(bs, k, n) =
ck::type_convert<float>(v_b) *
ck::type_convert<float>(b1_g_k_n(bs, k / ScaleBlockK, n));
float out = ck::type_convert<float>(v_b) *
ck::type_convert<float>(b1_g_k_n(bs, k / ScaleBlockK, n));
b_g_k_n_dequant(bs, k, n) = out;
}
}
}
using ReferenceBatchedGemmInstance =
ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
BScaleDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp,
ComputeDataType>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_g_m_k,
b_g_k_n_dequant,
c_g_m_n_host_result,
a_element_op,
b_element_op,
c_element_op);
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker();
auto ref_argument = ref_batched_gemm.MakeArgument(a_g_m_k,
b_g_k_n_dequant,
c_g_m_n_host_result,
a_element_op,
b_element_op,
c_element_op,
KBatch);
ref_invoker.Run(ref_argument);
}
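The verification loop above has to undo both the int4 packing and, when CK_USE_PK4_LAYOUT_SHUFFLE is defined, the K-index shuffle applied on the device side (see unary_element_wise_operation.hpp). A minimal standalone sketch of that mapping, with hypothetical helper names and the same nibble convention as the loop (low nibble for even k, high nibble for odd k, values stored with an offset of 8):
#include <cstdint>
// Logical k index -> shuffled storage index (mirrors the expression used in the loop above).
inline int shuffled_k(int k) { return (k / 8) * 8 + (k % 2) * 4 + (k % 8) / 2; }
// Dequantize one packed-int4 value from the byte that stores it, applying its block scale.
inline float dequant_pk_i4(uint8_t packed, int k_shuffle, float scale)
{
    int i4 = (k_shuffle % 2 == 0) ? (packed & 0xf) : ((packed >> 4) & 0xf);
    return static_cast<float>(i4 - 8) * scale; // stored range 0..15, logical range -8..7
}
For example, k = 3 maps to storage index (3 / 8) * 8 + (3 % 2) * 4 + (3 % 8) / 2 = 0 + 4 + 1 = 5.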
@@ -230,6 +240,7 @@ bool profile_batched_gemm_b_scale_impl(int do_verification,
if(op_ptr->GetPermuteB())
{
int K1 = KPerBlock;
int K0 = K / KPerBlock;
@@ -306,6 +317,7 @@ bool profile_batched_gemm_b_scale_impl(int do_verification,
}
else
{
b_g_k_n_permute = b_g_k_n;
}
@@ -375,8 +387,12 @@ bool profile_batched_gemm_b_scale_impl(int do_verification,
else
{
#endif
std::string msg = "Error: Incorrect results!";
double rtol = 1e-2;
double atol = 1e-2;
pass =
pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
pass & ck::utils::check_err(
c_g_m_n_device_result, c_g_m_n_host_result, msg, rtol, atol);
#if defined CK_ENABLE_FP8
}
#endif
@@ -407,13 +423,6 @@ bool profile_batched_gemm_b_scale_impl(int do_verification,
std::size_t flop = std::size_t(2) * M * N * K * BatchSize;
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
return 2;
else
return 1;
}();
std::size_t num_btype = sizeof(ADataType) * M * K +
sizeof(BDataType) * K * N / BPackedSize +
sizeof(CDataType) * M * N;

View File

@@ -105,7 +105,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 2});
b1_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
break;
default:
@@ -122,8 +122,16 @@ bool profile_gemm_b_scale_impl(int do_verification,
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
return 2;
else
return 1;
}();
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() /
BPackedSize);
DeviceMem b1_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
@@ -152,16 +160,24 @@ bool profile_gemm_b_scale_impl(int do_verification,
// Run reference GEMM
if(do_verification)
{
Tensor<float> b_k_n_dequant({K, N});
Tensor<BScaleDataType> b_k_n_dequant({K, N});
float v_b = 0;
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
ck::pk_i4_t i4x2 = b_k_n(k, n).data;
int8_t i4 = 0;
if(k % 2 == 1)
// for proper testing, we need to replicate k_shuffle when used
// see unary_element_wise_operation.hpp
#if CK_USE_PK4_LAYOUT_SHUFFLE
int k_shuffle = (k / 8) * 8 + (k % 2) * 4 + (k % 8) / 2;
#else
int k_shuffle = k;
#endif
ck::pk_i4_t i4x2 = b_k_n(k_shuffle, n).data;
int i4 = 0;
if(k_shuffle % 2 == 0)
i4 = (i4x2.data >> 0) & 0xf;
else
i4 = (i4x2.data >> 4) & 0xf;
@@ -173,7 +189,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
}
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
AccDataType,
BScaleDataType,
CDataType,
AccDataType,
AElementOp,
@@ -334,7 +350,11 @@ bool profile_gemm_b_scale_impl(int do_verification,
else
{
#endif
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
std::string msg = "Error: Incorrect results!";
double rtol = 2e-2;
double atol = 2e-2;
pass = pass & ck::utils::check_err(
c_m_n_device_result, c_m_n_host_result, msg, rtol, atol);
#if defined CK_ENABLE_FP8
}
#endif
@@ -365,13 +385,6 @@ bool profile_gemm_b_scale_impl(int do_verification,
std::size_t flop = std::size_t(2) * M * N * K;
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
return 2;
else
return 1;
}();
std::size_t num_btype = sizeof(ADataType) * M * K +
sizeof(BDataType) * K * N / BPackedSize +
sizeof(CDataType) * M * N;
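With a packed-int4 B operand, the byte count above works out as follows; a worked example with hypothetical fp16 A/C shapes (not taken from this commit):
// hypothetical example: F16 A and C, pk_i4 B, M = N = K = 1024
// flop      = 2 * 1024 * 1024 * 1024 = 2^31 flop
// num_btype = 2 * 1024 * 1024        // A, fp16
//           + 1 * 1024 * 1024 / 2    // B, two int4 values per byte
//           + 2 * 1024 * 1024        // C, fp16
//           = 2 MiB + 0.5 MiB + 2 MiB = 4.5 MiB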

View File

@@ -90,7 +90,7 @@ bool profile_gemm_universal_impl(int do_verification,
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 2});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

View File

@@ -67,7 +67,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
@@ -89,6 +88,7 @@ endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]")
list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
@@ -191,7 +191,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1
list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
endif()
list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
@@ -229,6 +228,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1
list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)

View File

@@ -57,7 +57,7 @@ int profile_batched_gemm_b_scale(int argc, char* argv[])
printf("arg6: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg7: print tensor value (0: no; 1: yes)\n");
printf("arg8: time kernel (0=no, 1=yes)\n");
printf("arg9 to 15: M, N, K, StrideA, StrideB, StrideC, BatachCount\n");
printf("arg9 to 15: M, N, K, StrideA, StrideB, StrideC, BatchCount\n");
printf("arg16: split k into mulitiple batch\n");
printf("optional:\n");
printf("arg17: number of warm-up cycles (default 1)\n");

View File

@@ -24,6 +24,7 @@ set(REGRESSION_TESTS
test_batched_gemm_softmax_gemm_permute_bf16
test_batched_gemm_bias_softmax_gemm_permute_bf16
test_grouped_gemm_splitk
test_batched_gemm_b_scale_wmma
test_reduce_no_index
test_reduce_with_index
test_convnd_fwd
@@ -257,6 +258,7 @@ add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(batched_gemm_softmax_gemm_permute)
add_subdirectory(batched_gemm_b_scale)
add_subdirectory(grouped_gemm)
add_subdirectory(reduce)
add_subdirectory(convnd_fwd)

View File

@@ -0,0 +1,5 @@
add_gtest_executable(test_batched_gemm_b_scale_wmma test_batched_gemm_b_scale_wmma.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_b_scale_wmma PRIVATE utility device_batched_gemm_b_scale_instance)
endif()

View File

@@ -0,0 +1,49 @@
#pragma once
TYPED_TEST(TestBatchedGemmBScale_MK_NK, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 256;
constexpr int K = 1024;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
constexpr int NBatches = 10;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC, NBatches);
}
TYPED_TEST(TestBatchedGemmBScale_MK_NK, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
constexpr int K = 768;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
constexpr int NBatches = 7;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC, NBatches);
}
TYPED_TEST(TestBatchedGemmBScale_MK_NK, Regular)
{
std::vector<int> Ms{512, 1024};
constexpr int N = 512;
constexpr int K = 1024;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
constexpr int NBatches = 3;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC, NBatches);
}

View File

@@ -0,0 +1,108 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <sstream>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/utility/data_type.hpp"
#include "profiler/profile_batched_gemm_b_scale_impl.hpp"
namespace ck {
namespace test {
template <typename Tuple>
class TestBatchedGemmBScale : public testing::Test
{
using Row = ck::tensor_layout::gemm::RowMajor;
using F32 = float;
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = Row;
using ADataType = std::tuple_element_t<2, Tuple>;
using BDataType = std::tuple_element_t<3, Tuple>;
using BScaleDataType = std::tuple_element_t<4, Tuple>;
using ComputeDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
public:
static constexpr ck::index_t ScaleBlockK = 128; // all instances
static constexpr bool verify_ = true;
static constexpr int init_method_ = 2;
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
std::vector<int> k_batches_;
void SetUp() override { k_batches_ = {1, 2}; }
void Run(const int M,
const int N,
const int K,
const int StrideA,
const int StrideB,
const int StrideC,
const int NBatch)
{
for(auto kb : k_batches_)
{
RunSingle(M, N, K, StrideA, StrideB, StrideC, NBatch, kb);
}
}
void RunSingle(const int M,
const int N,
const int K,
const int StrideA,
const int StrideB,
const int StrideC,
const int Nbatch,
int kbatch = 1,
int n_warmup = 1,
int n_iter = 10)
{
const int BatchStrideA = StrideA * M;
const int BatchStrideB = StrideB * K;
const int BatchStrideC = StrideC * M;
const int BatchStrideScaleB = StrideB * K;
bool pass = ck::profiler::profile_batched_gemm_b_scale_impl<ADataType,
BDataType,
BScaleDataType,
ComputeDataType,
F32,
CDataType,
ScaleBlockK,
ALayout,
BLayout,
CLayout>(verify_,
init_method_,
log_,
bench_,
M,
N,
K,
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
BatchStrideScaleB,
Nbatch,
kbatch,
n_warmup,
n_iter);
EXPECT_TRUE(pass);
}
};
} // namespace test
} // namespace ck
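As a usage sketch, the batch strides computed by RunSingle for one of the SmallM cases above (values follow directly from the fixture; layouts are row-major A, column-major B):
// e.g. M = 4, N = 256, K = 1024, StrideA = StrideB = 1024, StrideC = 256
// BatchStrideA      = StrideA * M = 1024 * 4    = 4096     elements per batch of A
// BatchStrideB      = StrideB * K = 1024 * 1024 = 1048576  elements per batch of B
// BatchStrideC      = StrideC * M = 256 * 4     = 1024     elements per batch of C
// BatchStrideScaleB = StrideB * K = 1048576     (as computed by the fixture)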

View File

@@ -0,0 +1,45 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_batched_gemm_b_scale_util.hpp"
using I4 = ck::pk_i4_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
namespace {
template <typename X, typename Y>
struct tuple_concat;
template <typename... Xs, typename... Ys>
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
{
using type = std::tuple<Xs..., Ys...>;
};
} // namespace
template <typename Tuple>
class TestBatchedGemmBScale_MK_NK : public ck::test::TestBatchedGemmBScale<
typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
{
};
// clang-format off
using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, BScaleDataType, ComputeDataType, CDataType
std::tuple< F16, I4, F16, F16, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestBatchedGemmBScale_MK_NK, KernelTypes_MK_NK);
#include "test_batched_gemm_b_scale_ut_cases.inc"