mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
implement device batched gemm b scale for wmma (#2825)
* rebased on top of develop
* fixed missing shuffeling and wrong indexing
* added tests for batched_b_scale
* added missing files
* fixed wrong stride computation and removed k batching (for now) due to precision issues
* reinstated k-batching with PRNG constrained to -1..1
* added specialization of GeneratorTensor_3 for int4 and fixed internal overflow
* added k-batching to reference and increased tolerances for test
* changed gemm_b_scale and gemm_universal tests to use correct parameters
* adressed review commentsd
* ported fixes back to non-batched version of b_scale
* adressed review comments
* run clang-format on older commits
* add type-conversion to AccDataType and then to CDataType to exactly mimic GPU's behavior
* added newline at end of file
* reflected changes from muitl-abd branch in batched b_scale
* fixed gfx11 issue
* changed range for pki4 to -1...1 (-0.5...0.5 never really made sense for i4 anyway and always should have caused compiler errors, but since there was no int4 specialization of GeneratorTensor3 until now, this passed
* run clang format
* set range of i4 generation to 0...1 for upstream tests to pass. This replicated previous behavior, which however means that it is NOT properly tested.
* reduced range for pk_i4 even further to 0..0
* removed failing xld instances. Failure now uncovered now that tests were fixed
* removed generation of int4 values entierly
* divide B buffer by BPackedSize
---------
Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
[ROCm/composable_kernel commit: c4b2da9cbd]
This commit is contained in:
@@ -289,7 +289,6 @@ int main(int argc, char* argv[])
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
@@ -303,7 +302,6 @@ int main(int argc, char* argv[])
|
||||
break;
|
||||
default:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
|
||||
@@ -275,7 +275,7 @@ int main(int argc, char* argv[])
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-1, 1});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
@@ -289,7 +289,7 @@ int main(int argc, char* argv[])
|
||||
break;
|
||||
default:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-1, 1});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
|
||||
@@ -264,7 +264,7 @@ struct GeneratorTensor_2<ck::pk_i4_t>
|
||||
{
|
||||
int hi = std::rand() % (max_value - min_value) + min_value + 8;
|
||||
int lo = std::rand() % (max_value - min_value) + min_value + 8;
|
||||
ck::pk_i4_t r = ((hi << 4) + lo) & 0xff;
|
||||
ck::pk_i4_t r = (((hi & 0xf) << 4) + (lo & 0xf));
|
||||
return r;
|
||||
}
|
||||
};
|
||||
@@ -436,6 +436,22 @@ struct GeneratorTensor_3<ck::f4x2_pk_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_3<ck::pk_i4_t>
|
||||
{
|
||||
int min_value = 0;
|
||||
int max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::pk_i4_t operator()(Is...)
|
||||
{
|
||||
int hi = std::rand() % (max_value - min_value) + min_value + 8;
|
||||
int lo = std::rand() % (max_value - min_value) + min_value + 8;
|
||||
ck::pk_i4_t r = (((hi & 0xf) << 4) + (lo & 0xf));
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_3<ck::f6x32_pk_t>
|
||||
{
|
||||
|
||||
@@ -0,0 +1,836 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename ComputePtrOffsetOfStridedBatch,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_batched_gemm_b_scale_wmma_cshuffle_v3(
|
||||
typename GridwiseGemm::Argument karg, // This works for now but it actually receives a
|
||||
// DeviceBatchedGemm_Wmma_CShuffleV3::Argument
|
||||
// argument through implicit conversion to base class!
|
||||
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<c_data_type, ck::half_t> ||
|
||||
std::is_same_v<c_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
// The normal approach to batching would be to increase the grid size by just stretching out
|
||||
// the grid Z dimension (which is the outermost dimension), but this depends on lower level
|
||||
// functions not directly using the Z dimension for other calculations. As it turns out, k
|
||||
// batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now
|
||||
// we will use the grid Y dimension for batching. This may be a bit fragile.
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t g_idx = amd_wave_read_first_lane(blockIdx.y);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t c_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
|
||||
const long_index_t b_scale_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx));
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
// shift A matrices pointer for splitk
|
||||
typename GridwiseGemm::AsGridPointer p_as_grid_shift;
|
||||
static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ =
|
||||
remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::AsDataType_>>;
|
||||
p_as_grid_shift(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
|
||||
splitk_batch_offset.a_k_split_offset[i] + a_batch_offset;
|
||||
});
|
||||
|
||||
// shift B matrices pointer for splitk
|
||||
typename GridwiseGemm::BsGridPointer p_bs_grid_shift;
|
||||
static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ =
|
||||
remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::BsDataType_>>;
|
||||
p_bs_grid_shift(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
|
||||
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
|
||||
});
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
p_as_grid_shift,
|
||||
p_bs_grid_shift,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
|
||||
karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op);
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
ignore = karg;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
#endif
|
||||
}
|
||||
|
||||
/// @brief \"Universal\" Batched GEMM operation without SplitK support.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This GEMM operation implements the following mathematical equation:
|
||||
/// C{G,M,N} = C_op(A_op(A{G,M,K}) * B_op(B{G,K,N}))
|
||||
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
|
||||
/// elementwise operations applied to the A, B, and C tensors, respectively.
|
||||
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
|
||||
/// scenarios. That's why it's called \"universal\". It's universal through its design
|
||||
/// and versatilty.
|
||||
///
|
||||
/// @note This Kernel implementation currently does not support the SplitK algorithm.
|
||||
///
|
||||
/// @tparam ALayout A tensor data layout.
|
||||
/// @tparam BLayout B tensor data layout.
|
||||
/// @tparam CLayout C tensor data layout.
|
||||
/// @tparam ADataType A tensor data type.
|
||||
/// @tparam BDataType B tensor data type.
|
||||
/// @tparam CDataType C tensor data type.
|
||||
/// @tparam AccDataType The accumulation data type related to the hardware
|
||||
/// matrix-multiplication instruction.
|
||||
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
|
||||
/// LDS memory during \"CShuffle\" data layout optimization.
|
||||
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
|
||||
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
|
||||
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
|
||||
/// (after GEMM).
|
||||
/// @tparam GemmSpec Determines used "padding" version.
|
||||
/// @tparam BlockSize The number of threads within workgroup.
|
||||
/// @tparam MPerBlock The input/output data tile size in the M dimension.
|
||||
/// @tparam NPerBlock The input/output data tile size in the N dimension.
|
||||
/// @tparam KPerBlock The input data tile size in the K dimension.
|
||||
/// @tparam AK1 The vector load size from global memory for A tensor.
|
||||
/// @tparam BK1 The vector load size from global memory for B tensor.
|
||||
/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction.
|
||||
/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction.
|
||||
/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront.
|
||||
/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront.
|
||||
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
|
||||
/// data. Can be interpreted as the answer
|
||||
/// to the question, "How many threads can be
|
||||
/// arranged on each input data axis?"
|
||||
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
|
||||
/// the input tensor dimension. Can be interpreted
|
||||
/// as the answer to the question: "In which
|
||||
/// order to spread threads through tensor axes?".
|
||||
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
|
||||
/// interpreted as the answer to the question "Which dimension
|
||||
/// to read first? And which next?" etc.
|
||||
/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
|
||||
/// access - the one with contiguous memory.
|
||||
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
|
||||
/// elements accessed per thread per instruction.
|
||||
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
|
||||
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
|
||||
/// universal GEMM there's no need for padding.
|
||||
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
|
||||
/// data. Can be interpreted as the answer
|
||||
/// to the question: "How many threads to
|
||||
/// arrange on each input data axis?"
|
||||
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
|
||||
/// the input tensor dimension. Can be interpreted
|
||||
/// as the answer to the question: "In which
|
||||
/// order to spread threads through tensor axes?".
|
||||
/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be
|
||||
/// interpreted as the answer to the question "Which dimension
|
||||
/// to read first? And which next?" etc.
|
||||
/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
|
||||
/// access - the one with contiguous memory.
|
||||
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
|
||||
/// elements accessed per thread per instruction.
|
||||
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
|
||||
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
|
||||
/// universal GEMM there's no need for padding.
|
||||
/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in M dimension.
|
||||
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in N dimension.
|
||||
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
|
||||
/// thread distribution used for storing data into output
|
||||
/// tensor across output data layout dimensions.
|
||||
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
|
||||
/// Used when storing data to output tensor.
|
||||
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
|
||||
/// intrawave).
|
||||
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
|
||||
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
|
||||
/// instructions.
|
||||
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
|
||||
/// instructions.
|
||||
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
|
||||
/// in global memory. Currently not supported!
|
||||
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
|
||||
/// in global memory (pre-shuffled). Currently not supported!
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t ScaleBlockN, // scale block for N
|
||||
index_t ScaleBlockK, // scale block for K
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale
|
||||
: public DeviceBatchedGemmV2BScale<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
CDataType,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
// We are inheriting from DeviceBatchedGemm and this base class does not support permuteA and
|
||||
// permuteB arguments so for now we are not including this functionality.
|
||||
static_assert(PermuteA == false,
|
||||
"Permute A functionality not supported by DeviceBatchedGemm operations.\n");
|
||||
static_assert(PermuteB == false,
|
||||
"Permute B functionality not supported by DeviceBatchedGemm operations.\n");
|
||||
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
index_t BatchStrideScaleB)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB_(BatchStrideB),
|
||||
BatchStrideC_(BatchStrideC),
|
||||
BatchStrideScaleB_(BatchStrideScaleB)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB_) / GridwiseGemm::BPackedSize;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_);
|
||||
}
|
||||
__host__ __device__ constexpr long_index_t GetScaleBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideScaleB_);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
index_t BatchStrideC_;
|
||||
index_t BatchStrideScaleB_;
|
||||
};
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<>, // DsLayout
|
||||
CLayout,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
BScaleDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
Tuple<>, // DsDataType
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
PermuteA, // PermuteA not supported by DeviceBatchedGemm base class.
|
||||
PermuteB>; // PermuteB not supported by DeviceBatchedGemm base class.
|
||||
|
||||
// Argument
|
||||
struct Argument : public GridwiseGemm::Argument
|
||||
{
|
||||
__host__ Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
CDataType* p_c_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t StrideScaleB_,
|
||||
index_t BatchStrideA_,
|
||||
index_t BatchStrideB_,
|
||||
index_t BatchStrideC_,
|
||||
index_t BatchStrideScaleB_,
|
||||
const BScaleDataType* p_b_scale_grid_,
|
||||
index_t Batch_,
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CElementwiseOperation c_element_op_,
|
||||
bool is_reduce_ = false)
|
||||
: GridwiseGemm::Argument(std::array<const void*, 1>{p_a_grid_},
|
||||
std::array<const void*, 1>{p_b_grid_},
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
p_c_grid_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
std::array<index_t, 1>{StrideA_},
|
||||
std::array<index_t, 1>{StrideB_},
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC_,
|
||||
StrideScaleB_,
|
||||
p_b_scale_grid_,
|
||||
k_batch_,
|
||||
a_element_op_,
|
||||
b_element_op_,
|
||||
c_element_op_,
|
||||
is_reduce_),
|
||||
Batch(Batch_),
|
||||
compute_ptr_offset_of_batch{
|
||||
BatchStrideA_, BatchStrideB_, BatchStrideC_, BatchStrideScaleB_}
|
||||
{
|
||||
}
|
||||
|
||||
index_t Batch;
|
||||
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;
|
||||
};
|
||||
|
||||
/// @brief Helper structure responsible for kernel invocation.
|
||||
///
|
||||
/// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU
|
||||
/// kernel function. It usually determines the launched grid size prepares kernel
|
||||
/// arguments as well as perform specific kernel configuration selection based on
|
||||
/// runtime arguments.
|
||||
///
|
||||
/// @note If appropriately configured it may measure kernel execution time.
|
||||
///
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
/// @brief This function issues GPU kernel execution.
|
||||
/// @param arg The GPU kernel arguments.
|
||||
/// @param stream_config The HIP stream configuration helper structure.
|
||||
/// @return The kernel's average execution time (if time measurement is
|
||||
/// enabled).
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
|
||||
|
||||
// The normal approach to batching would be to increase the grid size by just stretching
|
||||
// out the grid Z dimension (which is the outermost dimension), but this depends on
|
||||
// lower level functions not directly using the Z dimension for other calculations. As
|
||||
// it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset.
|
||||
// Therefore, for now we will use the grid Y dimension for batching. This may be a bit
|
||||
// fragile.
|
||||
gdy *= arg.Batch;
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);
|
||||
|
||||
// Packed sizes are 1 for all implemented data types but we include it anyway
|
||||
// for future compatibility.
|
||||
// Note: the grid descriptors and size_a / size_b do *not* take batching into
|
||||
// account, so we have to manually multiply overall buffer sizes for rotating
|
||||
// memory by batch.
|
||||
std::array<std::size_t, 1> size_as_buffers;
|
||||
size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() *
|
||||
sizeof(ADataType) / GridwiseGemm::APackedSize * arg_.Batch;
|
||||
|
||||
std::array<std::size_t, 1> size_bs_buffers;
|
||||
size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() *
|
||||
sizeof(BDataType) / GridwiseGemm::BPackedSize * arg_.Batch;
|
||||
|
||||
ck::utility::RotatingMemWrapperMultiABD<Argument,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
Tuple<>>
|
||||
rotating_mem(arg_,
|
||||
stream_config.rotating_count,
|
||||
size_as_buffers,
|
||||
size_bs_buffers,
|
||||
std::array<std::size_t, 0>{});
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
ck::utility::flush_icache();
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
// Note: we multiply by batch since we want to clear the C matrix for
|
||||
// the whole batch. Untested since we don't have k batching ATM.
|
||||
// Note: This seems incorrect for non-contiguous memory layouts for C
|
||||
// (padding, gaps).
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemsetAsync(arg_.p_e_grid,
|
||||
0,
|
||||
arg_.Batch * arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_,
|
||||
arg_.compute_ptr_offset_of_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto clear_workspace = [&]() {
|
||||
// clear c mem
|
||||
if(arg.KBatch > 1)
|
||||
// Note: we multiply by batch since we want to clear the C matrix for
|
||||
// the whole batch. Untested since we don't have k batching ATM.
|
||||
// Note: This seems incorrect for non-contiguous memory layouts for C
|
||||
// (padding, gaps).
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemsetAsync(arg.p_e_grid,
|
||||
0,
|
||||
arg.Batch * arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg,
|
||||
arg.compute_ptr_offset_of_batch);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy = []() {
|
||||
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}();
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<ComputePtrOffsetOfStridedBatch>,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Pipeline not implemented");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<ComputePtrOffsetOfStridedBatch>,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<CDataType, ck::half_t> ||
|
||||
std::is_same_v<CDataType, ck::bhalf_t>)
|
||||
{
|
||||
if(arg.KBatch > 1 && ck::is_gfx11_supported())
|
||||
{
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
|
||||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
index_t GetKPerBlock() override { return KPerBlock; }
|
||||
bool GetPermuteB() override { return PermuteB; }
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t StrideScaleB,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
index_t BatchStrideScaleB,
|
||||
const BScaleDataType* p_b_scale,
|
||||
index_t Batch,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
index_t KBatch = 1)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
StrideScaleB,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
BatchStrideScaleB,
|
||||
p_b_scale,
|
||||
Batch,
|
||||
KBatch};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t StrideScaleB,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
index_t BatchStrideScaleB,
|
||||
const void* p_b_scale,
|
||||
index_t Batch,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
StrideScaleB,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
BatchStrideScaleB,
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
Batch,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchedGemm_Wmma_CShuffleV3_BScale"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerWmma << "x"<<NPerWmma << ", "
|
||||
<< "WaveMap: "
|
||||
<< MRepeat << "x" << NRepeat << ", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
|
||||
<< "KPack: "
|
||||
<< GridwiseGemm::KPack;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
REGISTER_EXTRA_PRINTING_METHODS
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -222,6 +222,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
using typename Base::AsGridPointer;
|
||||
using typename Base::BsGridPointer;
|
||||
using typename Base::DsGridPointer;
|
||||
using AsDataType_ = AsDataType;
|
||||
using BsDataType_ = BsDataType;
|
||||
|
||||
struct Problem
|
||||
{
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include <stdexcept>
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -30,14 +31,18 @@ struct ReferenceBatchedGemm : public device::BaseOperator
|
||||
Tensor<CDataType>& c_g_m_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
CElementwiseOperation c_element_op,
|
||||
const int k_batch = 1)
|
||||
: a_g_m_k_{a_g_m_k},
|
||||
b_g_k_n_{b_g_k_n},
|
||||
c_g_m_n_{c_g_m_n},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
c_element_op_{c_element_op},
|
||||
k_batch_(k_batch)
|
||||
{
|
||||
if(k_batch < 1)
|
||||
throw std::invalid_argument("Batch size must be at least 1");
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_g_m_k_;
|
||||
@@ -47,6 +52,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
|
||||
const int k_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -59,23 +66,54 @@ struct ReferenceBatchedGemm : public device::BaseOperator
|
||||
auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) {
|
||||
const int K = arg.a_g_m_k_.mDesc.GetLengths()[2];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
// simulate fp accuacy implications of k batching
|
||||
std::vector<CDataType> partialSums(arg.k_batch_);
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
for(int batchIdx = 0; batchIdx < arg.k_batch_; ++batchIdx)
|
||||
{
|
||||
ADataType v_a;
|
||||
BDataType v_b;
|
||||
int batchSize = std::max(K / arg.k_batch_, 1);
|
||||
int batchStart = batchSize * batchIdx;
|
||||
int batchEnd = batchSize * (batchIdx + 1);
|
||||
// add any extra round-off to last batch
|
||||
if(batchIdx == arg.k_batch_ - 1)
|
||||
batchEnd = K;
|
||||
|
||||
arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
|
||||
arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
|
||||
AccDataType v_acc = 0;
|
||||
for(int k = batchStart; k < batchEnd; ++k)
|
||||
{
|
||||
ADataType v_a;
|
||||
BDataType v_b;
|
||||
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
|
||||
arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
|
||||
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
arg.c_element_op_(v_c, v_acc);
|
||||
partialSums[batchIdx] = ck::type_convert<CDataType>(v_c);
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.c_element_op_(v_c, v_acc);
|
||||
// finally, sum up partial sums
|
||||
// note that we can't simulate the random nature of atomic additions, but at least
|
||||
// we can simulate the effect of partial sums
|
||||
AccDataType v_c = 0;
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
for(int batchIdx = 0; batchIdx < arg.k_batch_; batchIdx++)
|
||||
{
|
||||
// mimic the way fp operations would be done on GPU for k-batching
|
||||
v_c = ck::type_convert<AccDataType>(ck::type_convert<CDataType>(
|
||||
ck::type_convert<AccDataType>(v_c) +
|
||||
ck::type_convert<AccDataType>(partialSums[batchIdx])));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_c = ck::type_convert<AccDataType>(partialSums[0]);
|
||||
}
|
||||
|
||||
arg.c_g_m_n_(g, m, n) = ck::type_convert<CDataType>(v_c);
|
||||
};
|
||||
@@ -108,9 +146,11 @@ struct ReferenceBatchedGemm : public device::BaseOperator
|
||||
Tensor<CDataType>& c_g_m_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
CElementwiseOperation c_element_op,
|
||||
const int k_batch = 1)
|
||||
{
|
||||
return Argument{a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op};
|
||||
return Argument{
|
||||
a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op, k_batch};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include <memory>
|
||||
@@ -16,6 +18,8 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
|
||||
void add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmV2BScale<Row,
|
||||
@@ -31,6 +35,25 @@ void add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8)) // TODO: really, or?
|
||||
void add_device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmV2BScale<Row,
|
||||
Col,
|
||||
Row,
|
||||
F16,
|
||||
I4,
|
||||
F16,
|
||||
F16,
|
||||
1,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif // CK_ENABLE_FP16 || CK_ENABLE_FP8
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -40,6 +63,7 @@ template <typename ADataType,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
index_t ScaleBlockK>
|
||||
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatchedGemmV2BScale<
|
||||
ALayout,
|
||||
BLayout,
|
||||
@@ -77,8 +101,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
#endif // CK_USE_XDL
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instances(
|
||||
op_ptrs);
|
||||
#endif // CK_USE_WMMA
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
set(BATCHED_GEMM_B_SCALE_INSTANCES)
|
||||
|
||||
list(APPEND BATCHED_GEMM_B_SCALE_INSTANCES
|
||||
device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
device_batched_gemm_b_scale_wmma_f16_i4_f16/device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instance.cpp
|
||||
)
|
||||
|
||||
set_source_files_properties(device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_batched_gemm_b_scale_wmma_f16_i4_f16/device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
|
||||
|
||||
add_instance_library(device_batched_gemm_b_scale_instance ${BATCHED_GEMM_B_SCALE_INSTANCES})
|
||||
@@ -0,0 +1,72 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I4 = pk_i4_t;
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| CLayout|AData|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Compute| Compute| PermuteA| PermuteB|
|
||||
//################################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| TypeA| TypeB| | |
|
||||
//################################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| Scheduler| Verision| | | | |
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //1
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //2
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //3
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //4
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //5
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //7
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //8
|
||||
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //9
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //10
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //11
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //12
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //13
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //14
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //15
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //16
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //17
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //18
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //19
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false> //20
|
||||
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmV2BScale<Row,
|
||||
Col,
|
||||
Row,
|
||||
F16,
|
||||
I4,
|
||||
F16,
|
||||
F16,
|
||||
1,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_instances<Intrawave,
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -45,9 +45,6 @@ using device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances = std::t
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //1
|
||||
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //3
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //4
|
||||
|
||||
//Latency friendly
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //5
|
||||
DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6
|
||||
|
||||
@@ -51,9 +51,6 @@ using device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple<
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //1
|
||||
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //3
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //4
|
||||
|
||||
//Latency friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //5
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -9,12 +9,13 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_b_scale.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
@@ -113,22 +114,21 @@ bool profile_batched_gemm_b_scale_impl(int do_verification,
|
||||
std::cout << "c_g_m_n: " << c_g_m_n_device_result.mDesc << std::endl;
|
||||
std::cout << "rotating count: " << rotating_count << std::endl;
|
||||
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
|
||||
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
break;
|
||||
// NOTE: for an int4, there is no point differentiating between decimal and integer
|
||||
// initialization also, the random number seem to be for a int4_2 type, so we use range 0...255
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
|
||||
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
}
|
||||
|
||||
@@ -141,7 +141,8 @@ bool profile_batched_gemm_b_scale_impl(int do_verification,
|
||||
const auto c_element_op = CElementOp{};
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize() /
|
||||
BPackedSize);
|
||||
DeviceMem b1_device_buf(sizeof(BScaleDataType) * b1_g_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
@@ -166,54 +167,63 @@ bool profile_batched_gemm_b_scale_impl(int do_verification,
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
// Run reference GEMM
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<float> b_g_k_n_dequant({K, N});
|
||||
Tensor<BScaleDataType> b_g_k_n_dequant({BatchSize, K, N});
|
||||
|
||||
float v_b = 0;
|
||||
for(int bs = 0; bs < BatchSize; bs++)
|
||||
{
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
ck::pk_i4_t i4x2 = b_g_k_n(bs, k, n).data;
|
||||
int8_t i4 = 0;
|
||||
if(k % 2 == 1)
|
||||
|
||||
// for proper testing, we need to replicate k_shuffle when used
|
||||
// see unary_element_wise_operation.hpp
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
int k_shuffle = (k / 8) * 8 + (k % 2) * 4 + (k % 8) / 2;
|
||||
#else
|
||||
int k_shuffle = k;
|
||||
#endif
|
||||
|
||||
ck::pk_i4_t i4x2 = b_g_k_n(bs, k_shuffle, n).data;
|
||||
int i4 = 0;
|
||||
if(k_shuffle % 2 == 0)
|
||||
i4 = (i4x2.data >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2.data >> 4) & 0xf;
|
||||
i4 = i4 - 8;
|
||||
i4 = i4 - 8;
|
||||
|
||||
v_b = ck::type_convert<float>(i4);
|
||||
|
||||
b_g_k_n_dequant(bs, k, n) =
|
||||
ck::type_convert<float>(v_b) *
|
||||
ck::type_convert<float>(b1_g_k_n(bs, k / ScaleBlockK, n));
|
||||
float out = ck::type_convert<float>(v_b) *
|
||||
ck::type_convert<float>(b1_g_k_n(bs, k / ScaleBlockK, n));
|
||||
|
||||
b_g_k_n_dequant(bs, k, n) = out;
|
||||
}
|
||||
}
|
||||
}
|
||||
using ReferenceBatchedGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
BScaleDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
ComputeDataType>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_g_m_k,
|
||||
b_g_k_n_dequant,
|
||||
c_g_m_n_host_result,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
|
||||
auto ref_invoker = ref_batched_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_batched_gemm.MakeArgument(a_g_m_k,
|
||||
b_g_k_n_dequant,
|
||||
c_g_m_n_host_result,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
KBatch);
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
|
||||
@@ -230,6 +240,7 @@ bool profile_batched_gemm_b_scale_impl(int do_verification,
|
||||
|
||||
if(op_ptr->GetPermuteB())
|
||||
{
|
||||
|
||||
int K1 = KPerBlock;
|
||||
int K0 = K / KPerBlock;
|
||||
|
||||
@@ -306,6 +317,7 @@ bool profile_batched_gemm_b_scale_impl(int do_verification,
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
b_g_k_n_permute = b_g_k_n;
|
||||
}
|
||||
|
||||
@@ -375,8 +387,12 @@ bool profile_batched_gemm_b_scale_impl(int do_verification,
|
||||
else
|
||||
{
|
||||
#endif
|
||||
std::string msg = "Error: Incorrect results!";
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
pass =
|
||||
pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
|
||||
pass & ck::utils::check_err(
|
||||
c_g_m_n_device_result, c_g_m_n_host_result, msg, rtol, atol);
|
||||
#if defined CK_ENABLE_FP8
|
||||
}
|
||||
#endif
|
||||
@@ -407,13 +423,6 @@ bool profile_batched_gemm_b_scale_impl(int do_verification,
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K * BatchSize;
|
||||
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K +
|
||||
sizeof(BDataType) * K * N / BPackedSize +
|
||||
sizeof(CDataType) * M * N;
|
||||
|
||||
@@ -105,7 +105,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 2});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
break;
|
||||
default:
|
||||
@@ -122,8 +122,16 @@ bool profile_gemm_b_scale_impl(int do_verification,
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto c_element_op = CElementOp{};
|
||||
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() /
|
||||
BPackedSize);
|
||||
DeviceMem b1_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
@@ -152,16 +160,24 @@ bool profile_gemm_b_scale_impl(int do_verification,
|
||||
// Run reference GEMM
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<float> b_k_n_dequant({K, N});
|
||||
Tensor<BScaleDataType> b_k_n_dequant({K, N});
|
||||
|
||||
float v_b = 0;
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
ck::pk_i4_t i4x2 = b_k_n(k, n).data;
|
||||
int8_t i4 = 0;
|
||||
if(k % 2 == 1)
|
||||
// for proper testing, we need to replicate k_shuffle when used
|
||||
// see unary_element_wise_operation.hpp
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
int k_shuffle = (k / 8) * 8 + (k % 2) * 4 + (k % 8) / 2;
|
||||
#else
|
||||
int k_shuffle = k;
|
||||
#endif
|
||||
|
||||
ck::pk_i4_t i4x2 = b_k_n(k_shuffle, n).data;
|
||||
int i4 = 0;
|
||||
if(k_shuffle % 2 == 0)
|
||||
i4 = (i4x2.data >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2.data >> 4) & 0xf;
|
||||
@@ -173,7 +189,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
|
||||
}
|
||||
}
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
AccDataType,
|
||||
BScaleDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
@@ -334,7 +350,11 @@ bool profile_gemm_b_scale_impl(int do_verification,
|
||||
else
|
||||
{
|
||||
#endif
|
||||
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
|
||||
std::string msg = "Error: Incorrect results!";
|
||||
double rtol = 2e-2;
|
||||
double atol = 2e-2;
|
||||
pass = pass & ck::utils::check_err(
|
||||
c_m_n_device_result, c_m_n_host_result, msg, rtol, atol);
|
||||
#if defined CK_ENABLE_FP8
|
||||
}
|
||||
#endif
|
||||
@@ -365,13 +385,6 @@ bool profile_gemm_b_scale_impl(int do_verification,
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K +
|
||||
sizeof(BDataType) * K * N / BPackedSize +
|
||||
sizeof(CDataType) * M * N;
|
||||
|
||||
@@ -90,7 +90,7 @@ bool profile_gemm_universal_impl(int do_verification,
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 2});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
|
||||
@@ -67,7 +67,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
|
||||
@@ -89,6 +88,7 @@ endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]")
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
|
||||
@@ -191,7 +191,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
|
||||
endif()
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
|
||||
@@ -229,6 +228,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
|
||||
|
||||
@@ -57,7 +57,7 @@ int profile_batched_gemm_b_scale(int argc, char* argv[])
|
||||
printf("arg6: initialization (0: no init; 1: integer value; 2: decimal value)\n");
|
||||
printf("arg7: print tensor value (0: no; 1: yes)\n");
|
||||
printf("arg8: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg9 to 15: M, N, K, StrideA, StrideB, StrideC, BatachCount\n");
|
||||
printf("arg9 to 15: M, N, K, StrideA, StrideB, StrideC, BatchCount\n");
|
||||
printf("arg16: split k into mulitiple batch\n");
|
||||
printf("optional:\n");
|
||||
printf("arg17: number of warm-up cycles (default 1)\n");
|
||||
|
||||
@@ -24,6 +24,7 @@ set(REGRESSION_TESTS
|
||||
test_batched_gemm_softmax_gemm_permute_bf16
|
||||
test_batched_gemm_bias_softmax_gemm_permute_bf16
|
||||
test_grouped_gemm_splitk
|
||||
test_batched_gemm_b_scale_wmma
|
||||
test_reduce_no_index
|
||||
test_reduce_with_index
|
||||
test_convnd_fwd
|
||||
@@ -257,6 +258,7 @@ add_subdirectory(batched_gemm_reduce)
|
||||
add_subdirectory(batched_gemm_gemm)
|
||||
add_subdirectory(batched_gemm_softmax_gemm)
|
||||
add_subdirectory(batched_gemm_softmax_gemm_permute)
|
||||
add_subdirectory(batched_gemm_b_scale)
|
||||
add_subdirectory(grouped_gemm)
|
||||
add_subdirectory(reduce)
|
||||
add_subdirectory(convnd_fwd)
|
||||
|
||||
5
test/batched_gemm_b_scale/CMakeLists.txt
Normal file
5
test/batched_gemm_b_scale/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
|
||||
add_gtest_executable(test_batched_gemm_b_scale_wmma test_batched_gemm_b_scale_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_batched_gemm_b_scale_wmma PRIVATE utility device_batched_gemm_b_scale_instance)
|
||||
endif()
|
||||
@@ -0,0 +1,49 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestBatchedGemmBScale_MK_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 1024;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
constexpr int NBatches = 10;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC, NBatches);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmBScale_MK_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 768;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
constexpr int NBatches = 7;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC, NBatches);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmBScale_MK_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512, 1024};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 1024;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
constexpr int NBatches = 3;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC, NBatches);
|
||||
}
|
||||
108
test/batched_gemm_b_scale/test_batched_gemm_b_scale_util.hpp
Normal file
108
test/batched_gemm_b_scale/test_batched_gemm_b_scale_util.hpp
Normal file
@@ -0,0 +1,108 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "include/ck/utility/data_type.hpp"
|
||||
#include "profiler/profile_batched_gemm_b_scale_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
template <typename Tuple>
|
||||
class TestBatchedGemmBScale : public testing::Test
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using F32 = float;
|
||||
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = Row;
|
||||
using ADataType = std::tuple_element_t<2, Tuple>;
|
||||
using BDataType = std::tuple_element_t<3, Tuple>;
|
||||
using BScaleDataType = std::tuple_element_t<4, Tuple>;
|
||||
using ComputeDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
|
||||
public:
|
||||
static constexpr ck::index_t ScaleBlockK = 128; // all instances
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 2;
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false; // measure kernel performance
|
||||
std::vector<int> k_batches_;
|
||||
|
||||
void SetUp() override { k_batches_ = {1, 2}; }
|
||||
|
||||
void Run(const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
const int StrideA,
|
||||
const int StrideB,
|
||||
const int StrideC,
|
||||
const int NBatch)
|
||||
{
|
||||
for(auto kb : k_batches_)
|
||||
{
|
||||
RunSingle(M, N, K, StrideA, StrideB, StrideC, NBatch, kb);
|
||||
}
|
||||
}
|
||||
|
||||
void RunSingle(const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
const int StrideA,
|
||||
const int StrideB,
|
||||
const int StrideC,
|
||||
const int Nbatch,
|
||||
int kbatch = 1,
|
||||
int n_warmup = 1,
|
||||
int n_iter = 10)
|
||||
{
|
||||
const int BatchStrideA = StrideA * M;
|
||||
const int BatchStrideB = StrideB * K;
|
||||
const int BatchStrideC = StrideC * M;
|
||||
const int BatchStrideScaleB = StrideB * K;
|
||||
bool pass = ck::profiler::profile_batched_gemm_b_scale_impl<ADataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ComputeDataType,
|
||||
F32,
|
||||
CDataType,
|
||||
ScaleBlockK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
BatchStrideScaleB,
|
||||
Nbatch,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
45
test/batched_gemm_b_scale/test_batched_gemm_b_scale_wmma.cpp
Normal file
45
test/batched_gemm_b_scale/test_batched_gemm_b_scale_wmma.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_batched_gemm_b_scale_util.hpp"
|
||||
|
||||
using I4 = ck::pk_i4_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestBatchedGemmBScale_MK_NK : public ck::test::TestBatchedGemmBScale<
|
||||
typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, BScaleDataType, ComputeDataType, CDataType
|
||||
std::tuple< F16, I4, F16, F16, F16>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestBatchedGemmBScale_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
#include "test_batched_gemm_b_scale_ut_cases.inc"
|
||||
Reference in New Issue
Block a user