diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 5e306ac6dd..1cfe896b1b 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -289,7 +289,6 @@ int main(int argc, char* argv[]) case 0: break; case 1: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -303,7 +302,6 @@ int main(int argc, char* argv[]) break; default: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index 29e758f9d4..d44ca19d2f 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -275,7 +275,7 @@ int main(int argc, char* argv[]) break; case 3: a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-1, 1}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -289,7 +289,7 @@ int main(int argc, char* argv[]) break; default: a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-1, 1}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); diff --git a/include/ck/library/utility/host_tensor_generator.hpp b/include/ck/library/utility/host_tensor_generator.hpp index ab69412c15..fc433c15f0 100644 --- a/include/ck/library/utility/host_tensor_generator.hpp +++ b/include/ck/library/utility/host_tensor_generator.hpp @@ -264,7 +264,7 @@ struct GeneratorTensor_2 { int hi = std::rand() % (max_value - min_value) + min_value + 8; int lo = std::rand() % (max_value - min_value) + min_value + 8; - ck::pk_i4_t r = ((hi << 4) + lo) & 0xff; + ck::pk_i4_t r = (((hi & 0xf) << 4) + (lo & 0xf)); return r; } }; @@ -436,6 +436,22 @@ struct GeneratorTensor_3 } }; +template <> +struct GeneratorTensor_3 +{ + int min_value = 0; + int max_value = 1; + + template + ck::pk_i4_t operator()(Is...) 
+ { + int hi = std::rand() % (max_value - min_value) + min_value + 8; + int lo = std::rand() % (max_value - min_value) + min_value + 8; + ck::pk_i4_t r = (((hi & 0xf) << 4) + (lo & 0xf)); + return r; + } +}; + template <> struct GeneratorTensor_3 { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp new file mode 100644 index 0000000000..4f676528bc --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -0,0 +1,836 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_b_scale_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, // This works for now but it actually receives a + // DeviceBatchedGemm_Wmma_CShuffleV3::Argument + // argument through implicit conversion to base class! + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + // The normal approach to batching would be to increase the grid size by just stretching out + // the grid Z dimension (which is the outermost dimension), but this depends on lower level + // functions not directly using the Z dimension for other calculations. As it turns out, k + // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now + // we will use the grid Y dimension for batching. This may be a bit fragile. 
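+        // Illustration of the resulting index mapping (assuming CalculateGridSize()
+        // returns gdy == 1 before the host side scales it by Batch):
+        //   blockIdx.x -> output-tile index within one GEMM problem
+        //   blockIdx.y -> batch index (read below as g_idx)
+        //   blockIdx.z -> split-K index (consumed by SplitKBatchOffset)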
+ __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + const long_index_t b_scale_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx)); + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + // shift A matrices pointer for splitk + typename GridwiseGemm::AsGridPointer p_as_grid_shift; + static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { + using ADataType_ = + remove_cvref_t>; + p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) + + splitk_batch_offset.a_k_split_offset[i] + a_batch_offset; + }); + + // shift B matrices pointer for splitk + typename GridwiseGemm::BsGridPointer p_bs_grid_shift; + static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { + using BDataType_ = + remove_cvref_t>; + p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) + + splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; + }); + + GridwiseGemm::template Run( + p_as_grid_shift, + p_bs_grid_shift, + karg.p_ds_grid, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, + karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op); +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = compute_ptr_offset_of_batch; +#endif +} + +/// @brief \"Universal\" Batched GEMM operation without SplitK support. +/// +/// @par Overview +/// This GEMM operation implements the following mathematical equation: +/// C{G,M,N} = C_op(A_op(A{G,M,K}) * B_op(B{G,K,N})) +/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are +/// elementwise operations applied to the A, B, and C tensors, respectively. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through its design +/// and versatilty. +/// +/// @note This Kernel implementation currently does not support the SplitK algorithm. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam CLayout C tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam CDataType C tensor data type. +/// @tparam AccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor +/// (after GEMM). +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. 
+/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1 The vector load size from global memory for A tensor. +/// @tparam BK1 The vector load size from global memory for B tensor. +/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront. +/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" +/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" +/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. +/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. 
+/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. +/// Used when storing data to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). Currently not supported! +template +struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale + : public DeviceBatchedGemmV2BScale +{ + // We are inheriting from DeviceBatchedGemm and this base class does not support permuteA and + // permuteB arguments so for now we are not including this functionality. + static_assert(PermuteA == false, + "Permute A functionality not supported by DeviceBatchedGemm operations.\n"); + static_assert(PermuteB == false, + "Permute B functionality not supported by DeviceBatchedGemm operations.\n"); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideScaleB) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideC_(BatchStrideC), + BatchStrideScaleB_(BatchStrideScaleB) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_) / GridwiseGemm::BPackedSize; + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + __host__ __device__ constexpr long_index_t GetScaleBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideScaleB_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + index_t BatchStrideScaleB_; + }; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale< + ALayout, + BLayout, + Tuple<>, // DsLayout + CLayout, + Tuple, + Tuple, + BScaleDataType, + AccDataType, + CShuffleDataType, + Tuple<>, // DsDataType + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + 
CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, // PermuteA not supported by DeviceBatchedGemm base class. + PermuteB>; // PermuteB not supported by DeviceBatchedGemm base class. + + // Argument + struct Argument : public GridwiseGemm::Argument + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + index_t BatchStrideA_, + index_t BatchStrideB_, + index_t BatchStrideC_, + index_t BatchStrideScaleB_, + const BScaleDataType* p_b_scale_grid_, + index_t Batch_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + bool is_reduce_ = false) + : GridwiseGemm::Argument(std::array{p_a_grid_}, + std::array{p_b_grid_}, + std::array{}, // p_ds_grid_ + p_c_grid_, + M_, + N_, + K_, + std::array{StrideA_}, + std::array{StrideB_}, + std::array{}, // StrideDs_ + StrideC_, + StrideScaleB_, + p_b_scale_grid_, + k_batch_, + a_element_op_, + b_element_op_, + c_element_op_, + is_reduce_), + Batch(Batch_), + compute_ptr_offset_of_batch{ + BatchStrideA_, BatchStrideB_, BatchStrideC_, BatchStrideScaleB_} + { + } + + index_t Batch; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; + }; + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + /// @note If appropriately configured it may measure kernel execution time. + /// + struct Invoker : public BaseInvoker + { + /// @brief This function issues GPU kernel execution. + /// @param arg The GPU kernel arguments. + /// @param stream_config The HIP stream configuration helper structure. + /// @return The kernel's average execution time (if time measurement is + /// enabled). + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + // The normal approach to batching would be to increase the grid size by just stretching + // out the grid Z dimension (which is the outermost dimension), but this depends on + // lower level functions not directly using the Z dimension for other calculations. As + // it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset. + // Therefore, for now we will use the grid Y dimension for batching. This may be a bit + // fragile. 
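+            // Worked example with hypothetical sizes (illustration only): with
+            // M = N = 256, MPerBlock = NPerBlock = 128, Batch = 8, KBatch = 1, and
+            // assuming CalculateGridSize() packs all output tiles into gdx, this gives
+            // gdx = 4, gdy = 1, gdz = 1; the scaling below then launches
+            // dim3(4, 8, 1) so that blockIdx.y enumerates the 8 problems in the batch.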
+ gdy *= arg.Batch; + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0); + + // Packed sizes are 1 for all implemented data types but we include it anyway + // for future compatibility. + // Note: the grid descriptors and size_a / size_b do *not* take batching into + // account, so we have to manually multiply overall buffer sizes for rotating + // memory by batch. + std::array size_as_buffers; + size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() * + sizeof(ADataType) / GridwiseGemm::APackedSize * arg_.Batch; + + std::array size_bs_buffers; + size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() * + sizeof(BDataType) / GridwiseGemm::BPackedSize * arg_.Batch; + + ck::utility::RotatingMemWrapperMultiABD, + Tuple, + Tuple<>> + rotating_mem(arg_, + stream_config.rotating_count, + size_as_buffers, + size_bs_buffers, + std::array{}); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + ck::utility::flush_icache(); + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + // Note: we multiply by batch since we want to clear the C matrix for + // the whole batch. Untested since we don't have k batching ATM. + // Note: This seems incorrect for non-contiguous memory layouts for C + // (padding, gaps). + HIP_CHECK_ERROR( + hipMemsetAsync(arg_.p_e_grid, + 0, + arg_.Batch * arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_, + arg_.compute_ptr_offset_of_batch); + } + else + { + auto clear_workspace = [&]() { + // clear c mem + if(arg.KBatch > 1) + // Note: we multiply by batch since we want to clear the C matrix for + // the whole batch. Untested since we don't have k batching ATM. + // Note: This seems incorrect for non-contiguous memory layouts for C + // (padding, gaps). + HIP_CHECK_ERROR( + hipMemsetAsync(arg.p_e_grid, + 0, + arg.Batch * arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg, + arg.compute_ptr_offset_of_batch); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + else + { + throw std::runtime_error("Pipeline not implemented"); + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.KBatch > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideScaleB, + const BScaleDataType* p_b_scale, + index_t Batch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + index_t KBatch = 1) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchStrideScaleB, + p_b_scale, + Batch, + KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, 
+ index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideScaleB, + const void* p_b_scale, + index_t Batch, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchStrideScaleB, + static_cast(p_b_scale), + Batch, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceBatchedGemm_Wmma_CShuffleV3_BScale" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " + << "WaveTile: " + << MPerWmma << "x"< namespace ck { namespace tensor_operation { @@ -30,14 +31,18 @@ struct ReferenceBatchedGemm : public device::BaseOperator Tensor& c_g_m_n, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + const int k_batch = 1) : a_g_m_k_{a_g_m_k}, b_g_k_n_{b_g_k_n}, c_g_m_n_{c_g_m_n}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, - c_element_op_{c_element_op} + c_element_op_{c_element_op}, + k_batch_(k_batch) { + if(k_batch < 1) + throw std::invalid_argument("Batch size must be at least 1"); } const Tensor& a_g_m_k_; @@ -47,6 +52,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; CElementwiseOperation c_element_op_; + + const int k_batch_; }; // Invoker @@ -59,23 +66,54 @@ struct ReferenceBatchedGemm : public device::BaseOperator auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) { const int K = arg.a_g_m_k_.mDesc.GetLengths()[2]; - AccDataType v_acc = 0; + // simulate fp accuacy implications of k batching + std::vector partialSums(arg.k_batch_); - for(int k = 0; k < K; ++k) + for(int batchIdx = 0; batchIdx < arg.k_batch_; ++batchIdx) { - ADataType v_a; - BDataType v_b; + int batchSize = std::max(K / arg.k_batch_, 1); + int batchStart = batchSize * batchIdx; + int batchEnd = batchSize * (batchIdx + 1); + // add any extra round-off to last batch + if(batchIdx == arg.k_batch_ - 1) + batchEnd = K; - arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k)); - arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n)); + AccDataType v_acc = 0; + for(int k = batchStart; k < batchEnd; ++k) + { + ADataType v_a; + BDataType v_b; - v_acc += - ck::type_convert(v_a) * ck::type_convert(v_b); + arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k)); + arg.b_element_op_(v_b, 
arg.b_g_k_n_(g, k, n)); + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + + AccDataType v_c; + arg.c_element_op_(v_c, v_acc); + partialSums[batchIdx] = ck::type_convert(v_c); } - AccDataType v_c; - - arg.c_element_op_(v_c, v_acc); + // finally, sum up partial sums + // note that we can't simulate the random nature of atomic additions, but at least + // we can simulate the effect of partial sums + AccDataType v_c = 0; + if(arg.k_batch_ > 1) + { + for(int batchIdx = 0; batchIdx < arg.k_batch_; batchIdx++) + { + // mimic the way fp operations would be done on GPU for k-batching + v_c = ck::type_convert(ck::type_convert( + ck::type_convert(v_c) + + ck::type_convert(partialSums[batchIdx]))); + } + } + else + { + v_c = ck::type_convert(partialSums[0]); + } arg.c_g_m_n_(g, m, n) = ck::type_convert(v_c); }; @@ -108,9 +146,11 @@ struct ReferenceBatchedGemm : public device::BaseOperator Tensor& c_g_m_n, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + const int k_batch = 1) { - return Argument{a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op}; + return Argument{ + a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op, k_batch}; } static auto MakeInvoker() { return Invoker{}; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_b_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_b_scale.hpp index 9f4b31528b..c57c69d91c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_b_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_b_scale.hpp @@ -5,6 +5,8 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp" + #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include @@ -16,6 +18,8 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { + +#if defined(CK_USE_XDL) #if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8)) void add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( std::vector>>& instances); #endif +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) +#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8)) // TODO: really, or? 
+void add_device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instances( + std::vector>>& instances); +#endif // CK_ENABLE_FP16 || CK_ENABLE_FP8 +#endif // CK_USE_WMMA template + struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { +#if defined(CK_USE_XDL) add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( op_ptrs); +#endif // CK_USE_XDL +#if defined(CK_USE_WMMA) + add_device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instances( + op_ptrs); +#endif // CK_USE_WMMA } } diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/CMakeLists.txt index 3221f4c17e..77295ed151 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/CMakeLists.txt @@ -1,10 +1,13 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(BATCHED_GEMM_B_SCALE_INSTANCES) list(APPEND BATCHED_GEMM_B_SCALE_INSTANCES device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp + device_batched_gemm_b_scale_wmma_f16_i4_f16/device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instance.cpp ) set_source_files_properties(device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_batched_gemm_b_scale_wmma_f16_i4_f16/device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + add_instance_library(device_batched_gemm_b_scale_instance ${BATCHED_GEMM_B_SCALE_INSTANCES}) \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/device_batched_gemm_b_scale_wmma_f16_i4_f16/device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/device_batched_gemm_b_scale_wmma_f16_i4_f16/device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000..8cf9933d6c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/device_batched_gemm_b_scale_wmma_f16_i4_f16/device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I4 = pk_i4_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| CLayout|AData|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Compute| Compute| PermuteA| PermuteB| + //################################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| TypeA| TypeB| | | + //################################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| Scheduler| Verision| | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //1 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, 
F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //2 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //3 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //4 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //5 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //7 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //8 + + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //9 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //10 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 
16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //11 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //12 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //13 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //14 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //15 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //16 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //17 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //18 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //19 + DeviceBatchedGemm_Wmma_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 
8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false> //20 + + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/device_batched_gemm_b_scale_wmma_f16_i4_f16/device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/device_batched_gemm_b_scale_wmma_f16_i4_f16/device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instance.cpp new file mode 100644 index 0000000000..5203beb92c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/device_batched_gemm_b_scale_wmma_f16_i4_f16/device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instance.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp index 9abe6f95b6..1a8b10ab30 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp @@ -45,9 +45,6 @@ using device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances = std::t DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0 DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //1 - DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //3 - DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, 
I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //4 - //Latency friendly DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //5 DeviceBatchedGemm_Xdl_CShuffleV3_BScale< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6 diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp index 1f8ca4d23a..46e569e3c7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp @@ -51,9 +51,6 @@ using device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple< DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0 DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //1 - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //3 - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //4 - //Latency friendly DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //5 DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6 diff --git a/profiler/include/profiler/profile_batched_gemm_b_scale_impl.hpp b/profiler/include/profiler/profile_batched_gemm_b_scale_impl.hpp index 060fbd70e5..357ab8d70f 100644 --- a/profiler/include/profiler/profile_batched_gemm_b_scale_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_b_scale_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -9,12 +9,13 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm_b_scale.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" @@ -113,22 +114,21 @@ bool profile_batched_gemm_b_scale_impl(int do_verification, std::cout << "c_g_m_n: " << c_g_m_n_device_result.mDesc << std::endl; std::cout << "rotating count: " << rotating_count << std::endl; + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + switch(init_method) { case 0: break; - case 1: - a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); - b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); - b1_g_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - break; - case 2: - a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - b1_g_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - break; + // NOTE: for an int4, there is no point differentiating between decimal and integer + // initialization also, the random number seem to be for a int4_2 type, so we use range 0...255 default: a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); b1_g_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); } @@ -141,7 +141,8 @@ bool profile_batched_gemm_b_scale_impl(int do_verification, const auto c_element_op = CElementOp{}; DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize() / + BPackedSize); DeviceMem b1_device_buf(sizeof(BScaleDataType) * b1_g_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_device_buf(sizeof(CDataType) * 
c_g_m_n_device_result.mDesc.GetElementSpaceSize()); @@ -166,54 +167,63 @@ bool profile_batched_gemm_b_scale_impl(int do_verification, DeviceOp>::GetInstances(); std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - // Run reference GEMM if(do_verification) { - Tensor b_g_k_n_dequant({K, N}); + Tensor b_g_k_n_dequant({BatchSize, K, N}); float v_b = 0; for(int bs = 0; bs < BatchSize; bs++) { for(int n = 0; n < N; n++) { + for(int k = 0; k < K; k++) { - ck::pk_i4_t i4x2 = b_g_k_n(bs, k, n).data; - int8_t i4 = 0; - if(k % 2 == 1) + + // for proper testing, we need to replicate k_shuffle when used + // see unary_element_wise_operation.hpp +#if CK_USE_PK4_LAYOUT_SHUFFLE + int k_shuffle = (k / 8) * 8 + (k % 2) * 4 + (k % 8) / 2; +#else + int k_shuffle = k; +#endif + + ck::pk_i4_t i4x2 = b_g_k_n(bs, k_shuffle, n).data; + int i4 = 0; + if(k_shuffle % 2 == 0) i4 = (i4x2.data >> 0) & 0xf; else i4 = (i4x2.data >> 4) & 0xf; - i4 = i4 - 8; + i4 = i4 - 8; + v_b = ck::type_convert(i4); - b_g_k_n_dequant(bs, k, n) = - ck::type_convert(v_b) * - ck::type_convert(b1_g_k_n(bs, k / ScaleBlockK, n)); + float out = ck::type_convert(v_b) * + ck::type_convert(b1_g_k_n(bs, k / ScaleBlockK, n)); + + b_g_k_n_dequant(bs, k, n) = out; } } } + using ReferenceBatchedGemmInstance = + ck::tensor_operation::host::ReferenceBatchedGemm; - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a_g_m_k, - b_g_k_n_dequant, - c_g_m_n_host_result, - a_element_op, - b_element_op, - c_element_op); - + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + auto ref_argument = ref_batched_gemm.MakeArgument(a_g_m_k, + b_g_k_n_dequant, + c_g_m_n_host_result, + a_element_op, + b_element_op, + c_element_op, + KBatch); ref_invoker.Run(ref_argument); } @@ -230,6 +240,7 @@ bool profile_batched_gemm_b_scale_impl(int do_verification, if(op_ptr->GetPermuteB()) { + int K1 = KPerBlock; int K0 = K / KPerBlock; @@ -306,6 +317,7 @@ bool profile_batched_gemm_b_scale_impl(int do_verification, } else { + b_g_k_n_permute = b_g_k_n; } @@ -375,8 +387,12 @@ bool profile_batched_gemm_b_scale_impl(int do_verification, else { #endif + std::string msg = "Error: Incorrect results!"; + double rtol = 1e-2; + double atol = 1e-2; pass = - pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result); + pass & ck::utils::check_err( + c_g_m_n_device_result, c_g_m_n_host_result, msg, rtol, atol); #if defined CK_ENABLE_FP8 } #endif @@ -407,13 +423,6 @@ bool profile_batched_gemm_b_scale_impl(int do_verification, std::size_t flop = std::size_t(2) * M * N * K * BatchSize; - static constexpr index_t BPackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N / BPackedSize + sizeof(CDataType) * M * N; diff --git a/profiler/include/profiler/profile_gemm_b_scale_impl.hpp b/profiler/include/profiler/profile_gemm_b_scale_impl.hpp index 86370e2f47..8ca1350523 100644 --- a/profiler/include/profiler/profile_gemm_b_scale_impl.hpp +++ b/profiler/include/profiler/profile_gemm_b_scale_impl.hpp @@ -105,7 +105,7 @@ bool profile_gemm_b_scale_impl(int do_verification, break; case 2: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + 
b_k_n.GenerateTensorValue(GeneratorTensor_3{-1, 2}); b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); break; default: @@ -122,8 +122,16 @@ bool profile_gemm_b_scale_impl(int do_verification, const auto b_element_op = BElementOp{}; const auto c_element_op = CElementOp{}; + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / + BPackedSize); DeviceMem b1_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); @@ -152,16 +160,24 @@ bool profile_gemm_b_scale_impl(int do_verification, // Run reference GEMM if(do_verification) { - Tensor b_k_n_dequant({K, N}); + Tensor b_k_n_dequant({K, N}); float v_b = 0; for(int n = 0; n < N; n++) { for(int k = 0; k < K; k++) { - ck::pk_i4_t i4x2 = b_k_n(k, n).data; - int8_t i4 = 0; - if(k % 2 == 1) + // for proper testing, we need to replicate k_shuffle when used + // see unary_element_wise_operation.hpp +#if CK_USE_PK4_LAYOUT_SHUFFLE + int k_shuffle = (k / 8) * 8 + (k % 2) * 4 + (k % 8) / 2; +#else + int k_shuffle = k; +#endif + + ck::pk_i4_t i4x2 = b_k_n(k_shuffle, n).data; + int i4 = 0; + if(k_shuffle % 2 == 0) i4 = (i4x2.data >> 0) & 0xf; else i4 = (i4x2.data >> 4) & 0xf; @@ -173,7 +189,7 @@ bool profile_gemm_b_scale_impl(int do_verification, } } using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, pk_i4_t>) - return 2; - else - return 1; - }(); - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N / BPackedSize + sizeof(CDataType) * M * N; diff --git a/profiler/include/profiler/profile_gemm_universal_impl.hpp b/profiler/include/profiler/profile_gemm_universal_impl.hpp index bb73c4e3da..bee907dd76 100644 --- a/profiler/include/profiler/profile_gemm_universal_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_impl.hpp @@ -90,7 +90,7 @@ bool profile_gemm_universal_impl(int do_verification, break; case 2: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-1, 2}); break; default: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index c31ede2c73..9f86f6d88f 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -67,7 +67,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_gemm_add.cpp) list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp) list(APPEND PROFILER_OPS profile_gemm_splitk.cpp) - list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp) list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp) list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp) list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp) @@ -89,6 +88,7 @@ endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") list(APPEND PROFILER_OPS profile_gemm_universal.cpp) list(APPEND PROFILER_OPS profile_batched_gemm.cpp) + list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp) list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp) list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp) list(APPEND PROFILER_OPS 
profile_grouped_conv_fwd.cpp) @@ -191,7 +191,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND DEVICE_INSTANCES device_gemm_mx_instance) endif() list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance) - list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance) list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance) list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance) @@ -229,6 +228,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND DEVICE_INSTANCES device_batched_gemm_instance) list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance) list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance) + list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance) diff --git a/profiler/src/profile_batched_gemm_b_scale.cpp b/profiler/src/profile_batched_gemm_b_scale.cpp index 5fe6f490be..5ed673e127 100644 --- a/profiler/src/profile_batched_gemm_b_scale.cpp +++ b/profiler/src/profile_batched_gemm_b_scale.cpp @@ -57,7 +57,7 @@ int profile_batched_gemm_b_scale(int argc, char* argv[]) printf("arg6: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg7: print tensor value (0: no; 1: yes)\n"); printf("arg8: time kernel (0=no, 1=yes)\n"); - printf("arg9 to 15: M, N, K, StrideA, StrideB, StrideC, BatachCount\n"); + printf("arg9 to 15: M, N, K, StrideA, StrideB, StrideC, BatchCount\n"); printf("arg16: split k into mulitiple batch\n"); printf("optional:\n"); printf("arg17: number of warm-up cycles (default 1)\n"); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 292bc41a0b..c16841d595 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -24,6 +24,7 @@ set(REGRESSION_TESTS test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16 test_grouped_gemm_splitk + test_batched_gemm_b_scale_wmma test_reduce_no_index test_reduce_with_index test_convnd_fwd @@ -257,6 +258,7 @@ add_subdirectory(batched_gemm_reduce) add_subdirectory(batched_gemm_gemm) add_subdirectory(batched_gemm_softmax_gemm) add_subdirectory(batched_gemm_softmax_gemm_permute) +add_subdirectory(batched_gemm_b_scale) add_subdirectory(grouped_gemm) add_subdirectory(reduce) add_subdirectory(convnd_fwd) diff --git a/test/batched_gemm_b_scale/CMakeLists.txt b/test/batched_gemm_b_scale/CMakeLists.txt new file mode 100644 index 0000000000..abc3d14ee1 --- /dev/null +++ b/test/batched_gemm_b_scale/CMakeLists.txt @@ -0,0 +1,5 @@ + +add_gtest_executable(test_batched_gemm_b_scale_wmma test_batched_gemm_b_scale_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_b_scale_wmma PRIVATE utility device_batched_gemm_b_scale_instance) +endif() diff --git a/test/batched_gemm_b_scale/test_batched_gemm_b_scale_ut_cases.inc b/test/batched_gemm_b_scale/test_batched_gemm_b_scale_ut_cases.inc new file mode 100644 index 0000000000..66cbaad323 --- /dev/null +++ b/test/batched_gemm_b_scale/test_batched_gemm_b_scale_ut_cases.inc @@ -0,0 +1,49 @@ +#pragma once + +TYPED_TEST(TestBatchedGemmBScale_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 256; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int 
StrideC = N; + + constexpr int NBatches = 10; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC, NBatches); +} + +TYPED_TEST(TestBatchedGemmBScale_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 768; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + constexpr int NBatches = 7; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC, NBatches); +} + +TYPED_TEST(TestBatchedGemmBScale_MK_NK, Regular) +{ + std::vector Ms{512, 1024}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + constexpr int NBatches = 3; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC, NBatches); +} diff --git a/test/batched_gemm_b_scale/test_batched_gemm_b_scale_util.hpp b/test/batched_gemm_b_scale/test_batched_gemm_b_scale_util.hpp new file mode 100644 index 0000000000..e413a762a3 --- /dev/null +++ b/test/batched_gemm_b_scale/test_batched_gemm_b_scale_util.hpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "include/ck/utility/data_type.hpp" +#include "profiler/profile_batched_gemm_b_scale_impl.hpp" + +namespace ck { +namespace test { + +template +class TestBatchedGemmBScale : public testing::Test +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using F32 = float; + + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = Row; + using ADataType = std::tuple_element_t<2, Tuple>; + using BDataType = std::tuple_element_t<3, Tuple>; + using BScaleDataType = std::tuple_element_t<4, Tuple>; + using ComputeDataType = std::tuple_element_t<5, Tuple>; + using CDataType = std::tuple_element_t<6, Tuple>; + + public: + static constexpr ck::index_t ScaleBlockK = 128; // all instances + static constexpr bool verify_ = true; + static constexpr int init_method_ = 2; + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + std::vector k_batches_; + + void SetUp() override { k_batches_ = {1, 2}; } + + void Run(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC, + const int NBatch) + { + for(auto kb : k_batches_) + { + RunSingle(M, N, K, StrideA, StrideB, StrideC, NBatch, kb); + } + } + + void RunSingle(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC, + const int Nbatch, + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) + { + const int BatchStrideA = StrideA * M; + const int BatchStrideB = StrideB * K; + const int BatchStrideC = StrideC * M; + const int BatchStrideScaleB = StrideB * K; + bool pass = ck::profiler::profile_batched_gemm_b_scale_impl(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchStrideScaleB, + Nbatch, + kbatch, + n_warmup, + n_iter); + EXPECT_TRUE(pass); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/batched_gemm_b_scale/test_batched_gemm_b_scale_wmma.cpp 
b/test/batched_gemm_b_scale/test_batched_gemm_b_scale_wmma.cpp new file mode 100644 index 0000000000..f004c78969 --- /dev/null +++ b/test/batched_gemm_b_scale/test_batched_gemm_b_scale_wmma.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_batched_gemm_b_scale_util.hpp" + +using I4 = ck::pk_i4_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestBatchedGemmBScale_MK_NK : public ck::test::TestBatchedGemmBScale< + typename tuple_concat, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, BScaleDataType, ComputeDataType, CDataType + std::tuple< F16, I4, F16, F16, F16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestBatchedGemmBScale_MK_NK, KernelTypes_MK_NK); + +#include "test_batched_gemm_b_scale_ut_cases.inc"
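
A note on the packed-int4 handling above: the profiler now divides the B device allocation by BPackedSize (2 for pk_i4_t) because two 4-bit values share one byte, and the verification loop recovers each value by selecting a nibble and removing the +8 bias. The standalone sketch below shows that arithmetic; pk_i4, packed_bytes and unpack are names invented here for illustration, not the CK types.

#include <cstddef>
#include <cstdint>
#include <cstdio>

struct pk_i4 { uint8_t data; }; // stand-in for ck::pk_i4_t: two 4-bit values in one byte

// Bytes needed for `elements` logical int4 values when two of them share a byte.
constexpr std::size_t packed_bytes(std::size_t elements, std::size_t BPackedSize = 2)
{
    return elements * sizeof(pk_i4) / BPackedSize;
}

// Recover the two signed values in [-8, 7]; the low nibble belongs to the even
// k index and the high nibble to the odd one, matching the verification loop.
inline void unpack(pk_i4 v, int& even, int& odd)
{
    even = (v.data & 0xf) - 8;
    odd  = ((v.data >> 4) & 0xf) - 8;
}

int main()
{
    const std::size_t K = 1024, N = 256;
    std::printf("B bytes: %zu (vs %zu unpacked)\n", packed_bytes(K * N), K * N);

    pk_i4 v{0x9a}; // high nibble 9, low nibble 10 -> odd = 1, even = 2 after removing the bias
    int even = 0, odd = 0;
    unpack(v, even, odd);
    std::printf("even = %d, odd = %d\n", even, odd);
    return 0;
}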
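
The host-side dequantization of B has to replicate the k-index shuffle applied when CK_USE_PK4_LAYOUT_SHUFFLE is enabled before it can pick the right nibble and apply the per-ScaleBlockK scale. Below is a self-contained sketch of that loop for a single column; dequantize_b_column, the flat byte/scale vectors and the boolean flag are simplifications chosen for this sketch rather than the patch's tensors.

#include <cstdint>
#include <vector>

// Dequantize one column of the int4 B operand the way the verification loop does:
// replicate the k-index shuffle (see the patch comment pointing at
// unary_element_wise_operation.hpp), pick the nibble by parity, remove the +8
// bias and apply the per-ScaleBlockK scale.
std::vector<float> dequantize_b_column(const std::vector<uint8_t>& b_col, // one packed byte per k
                                       const std::vector<float>& scales,  // K / ScaleBlockK entries
                                       int K,
                                       int ScaleBlockK,
                                       bool use_pk4_layout_shuffle)
{
    std::vector<float> out(K);
    for(int k = 0; k < K; ++k)
    {
        // Shuffle within each group of 8 k values: 0,1,2,...,7 -> 0,4,1,5,2,6,3,7.
        const int k_shuffle = use_pk4_layout_shuffle
                                  ? (k / 8) * 8 + (k % 2) * 4 + (k % 8) / 2
                                  : k;

        const uint8_t byte = b_col[k_shuffle];
        // Even (shuffled) indices read the low nibble, odd ones the high nibble.
        int i4 = (k_shuffle % 2 == 0) ? (byte & 0xf) : ((byte >> 4) & 0xf);
        i4 -= 8;

        out[k] = static_cast<float>(i4) * scales[k / ScaleBlockK];
    }
    return out;
}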
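
Verification then runs a batched reference GEMM over the dequantized B rather than the single-matrix ReferenceGemm used before. Functionally it amounts to the naive loop below; the dense row-major buffers are a simplification, and this is not the CK ReferenceBatchedGemm implementation.

#include <cstddef>
#include <vector>

// What the reference side of the check computes: for every batch g,
// C[g] = A[g] * Bdq[g], with Bdq the dequantized fp32 copy of B.
// Buffers are dense row-major (G*M*K, G*K*N, G*M*N).
void reference_batched_gemm(const std::vector<float>& A,
                            const std::vector<float>& Bdq,
                            std::vector<float>& C,
                            int G, int M, int N, int K)
{
    for(int g = 0; g < G; ++g)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += A[(std::size_t(g) * M + m) * K + k] *
                           Bdq[(std::size_t(g) * K + k) * N + n];
                C[(std::size_t(g) * M + m) * N + n] = acc;
            }
}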
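
Because the kernel accumulates dequantized int4 data in low precision, the profiler now passes explicit tolerances (rtol = atol = 1e-2) to ck::utils::check_err instead of relying on the defaults. The check it requests is approximately the mixed absolute/relative comparison sketched here; all_close is an illustrative stand-in, not the CK routine.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// An element passes if |out - ref| <= atol + rtol * |ref|.
bool all_close(const std::vector<float>& out,
               const std::vector<float>& ref,
               double rtol = 1e-2,
               double atol = 1e-2)
{
    if(out.size() != ref.size())
        return false;
    bool pass = true;
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const double err = std::fabs(double(out[i]) - double(ref[i]));
        if(err > atol + rtol * std::fabs(double(ref[i])))
        {
            std::printf("mismatch at %zu: %f vs %f\n", i, double(out[i]), double(ref[i]));
            pass = false;
        }
    }
    return pass;
}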
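
In the test fixture, RunSingle derives the batch strides from the per-matrix strides (BatchStrideA = StrideA * M, BatchStrideB = StrideB * K, BatchStrideC = StrideC * M). Assuming the usual conventions for row-major A/C and column-major B, the resulting flat element offsets look like the sketch below; the helper names are invented for illustration.

#include <cstddef>

// Flat element offsets implied by the strides RunSingle passes to the profiler,
// assuming row-major A (M x K) and C (M x N) and column-major B (K x N) where
// element (k, n) sits at n * StrideB + k.
constexpr std::size_t a_offset(int g, int m, int k, int StrideA, int BatchStrideA)
{
    return std::size_t(g) * BatchStrideA + std::size_t(m) * StrideA + k;
}

constexpr std::size_t b_offset(int g, int k, int n, int StrideB, int BatchStrideB)
{
    return std::size_t(g) * BatchStrideB + std::size_t(n) * StrideB + k;
}

constexpr std::size_t c_offset(int g, int m, int n, int StrideC, int BatchStrideC)
{
    return std::size_t(g) * BatchStrideC + std::size_t(m) * StrideC + n;
}

// With StrideA = K and BatchStrideA = StrideA * M, consecutive batches of A
// (and likewise C) are densely packed.
static_assert(a_offset(1, 0, 0, /*StrideA=*/1024, /*BatchStrideA=*/1024 * 5) == 1024 * 5, "");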
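
The WMMA test translation unit builds its typed-test parameter list by concatenating the (ALayout, BLayout) pair with each data-type tuple via a small tuple_concat trait. One common way to write such a trait is shown below; the parameter names and placeholder element types are this sketch's, while the patch's actual list is std::tuple<F16, I4, F16, F16, F16> prefixed with (Row, Col).

#include <tuple>
#include <type_traits>

// Concatenate the element types of two std::tuple specializations.
template <typename, typename>
struct tuple_concat;

template <typename... Ts, typename... Us>
struct tuple_concat<std::tuple<Ts...>, std::tuple<Us...>>
{
    using type = std::tuple<Ts..., Us...>;
};

// Placeholder layout tags and element types standing in for the CK ones.
struct Row {};
struct Col {};
using DataTypes = std::tuple<float, int, float, float, float>;
using Full      = tuple_concat<std::tuple<Row, Col>, DataTypes>::type;

static_assert(std::is_same_v<std::tuple_element_t<0, Full>, Row>, "layouts come first");
static_assert(std::tuple_size_v<Full> == 7, "two layouts + five data types");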