Implement grouped gemm tile loop for RDNA4 (#3304)

* feat: grouped gemm tile loop support for RDNA4

* fix: removed extra parameter from grouped gemm example instance

* fix: FP8 check incorrectly enabling FP8 on RDNA3

Author: Erwin Terpstra
Date: 2026-01-13 07:14:23 +01:00 (committed by GitHub)
Parent: 141f77aa12
Commit: eb041079a3
44 changed files with 3067 additions and 1223 deletions


@@ -151,7 +151,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
static bool __host__ __device__ BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
static TailNumber BlockLoopTailNum(index_t num_loop)
{
@@ -707,7 +710,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
__host__ __device__ static bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
static TailNumber BlockLoopTailNum(index_t num_loop)
{


@@ -3,6 +3,11 @@
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/stream_utility.hpp"
#include "device_grouped_gemm.hpp"
namespace ck {
@@ -43,6 +48,59 @@ struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm<ALayout,
{
};
template <ck::index_t BlockSize>
struct TileLoopKernelConfig
{
// The oversubscription factor for the number of blocks that can simultaneously reside on
// the GPU.
static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
// static constexpr int BLOCK_WAVES = BlockSize / get_warp_size();
static constexpr int CU_SIMDS = 4;
// Assume we want to have at most 2 waves per SIMD
// static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
static int GetCuBlocks()
{
int BLOCK_WAVES = BlockSize / get_warp_size();
return ck::math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
}
template <typename KernelFunction>
static int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
const StreamConfig& stream_config)
{
// Calculate max number of workgroups that can simultaneously reside on the CU.
int occ_num_blocks = GetKernelOccupancy(kernel);
int cu_count = getAvailableComputeUnitCount(stream_config);
if(stream_config.log_level_ > 0)
{
std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks
<< ", available CUs count: " << cu_count << ", occup. grid size: "
<< ck::math::min(occ_num_blocks, GetCuBlocks()) * cu_count << std::endl;
}
return cu_count * ck::math::min(occ_num_blocks, GetCuBlocks());
}
template <typename KernelFunction>
static int GetKernelOccupancy(const KernelFunction& kernel)
{
int occupancy = 0;
ck::hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
return occupancy;
}
static int GetComputeUnitCount()
{
hipDeviceProp_t dev_prop;
hipDevice_t dev;
ck::hip_check_error(hipGetDevice(&dev));
ck::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
return dev_prop.multiProcessorCount;
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
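
For illustration, a minimal standalone sketch (not part of the change) of the grid sizing that TileLoopKernelConfig::CalculateMaxOccupancyGridSize performs; the occupancy and CU count below are placeholder numbers standing in for the hipOccupancyMaxActiveBlocksPerMultiprocessor and getAvailableComputeUnitCount results:

#include <algorithm>
#include <cstdio>

int main()
{
    const int block_size  = 256;
    const int warp_size   = 32;                            // RDNA wave size
    const int cu_simds    = 4;                             // KernelConfig::CU_SIMDS
    const int block_waves = block_size / warp_size;        // 8 waves per workgroup
    const int cu_blocks   = (2 * cu_simds) / block_waves;  // "at most 2 waves per SIMD" cap: 1
    const int occ_blocks  = 2;                             // stand-in occupancy query result
    const int cu_count    = 32;                            // stand-in available CU count
    const int grid_size   = std::min(occ_blocks, cu_blocks) * cu_count; // 1 * 32 = 32 workgroups
    std::printf("grid_size = %d\n", grid_size);
    return 0;
}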


@@ -0,0 +1,689 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <sstream>
#include <tuple>
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/host_utility/stream_utility.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
///
/// @brief Entry point kernel for device-wide Grouped GEMM operation.
///
/// @param[in] gemm_descs_const The pointer to the array of GEMM descriptor structures.
/// @param[in] group_count The number of GEMMs processed together.
///
/// @tparam GridwiseGemm The specific GridwiseGEMM algorithm implementation.
/// @tparam GemmDesc The structure holding all necessary descriptors and
/// other data needed for grouped gemm calculation and work
/// distribution.
/// @tparam LocalBlock2ETileMap The structure providing mapping between workgroup ids,
/// the data tiles to process and the output tiles.
///
template <typename GridwiseGemm,
typename GemmDesc,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
index_t KPerBlock,
typename OffsettedBlockToCTileMap,
typename LocalBlock2ETileMap,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_multiple_d_wmma(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op)
{
#if(defined(__gfx11__) || defined(__gfx12__))
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
__shared__ uint8_t p_shared[LDS_size];
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
constexpr auto NumDTensor = DsDataType::Size();
index_t tile_id = get_block_1d_id();
index_t tile_offset = 0;
index_t group_id = -1;
index_t group_offset = 0;
index_t grid_size_grp = 0;
index_t gemm_tile_id_start = 0;
index_t gemm_tile_id_end = 0;
index_t M = 0, N = 0, K = 0;
auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1);
do
{
// Find corresponding GEMM group for our tile
while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) &&
group_id < group_count)
{
group_offset += grid_size_grp;
group_id++;
if(group_id >= group_count)
return;
M = gemm_desc_ptr[group_id].M;
N = gemm_desc_ptr[group_id].N;
K = gemm_desc_ptr[group_id].K;
if(M == 0 || N == 0 || K == 0)
{
grid_size_grp = 0;
continue;
}
b2c_tile_map =
OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset);
grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);
gemm_tile_id_start = group_offset;
gemm_tile_id_end = group_offset + grid_size_grp;
}
// Create the A and B grid pointer tuples, each holding its single tensor
typename GridwiseGemm::AsGridPointer p_as_grid = Tuple<const ADataType*>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid));
typename GridwiseGemm::BsGridPointer p_bs_grid = Tuple<const BDataType*>(
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid));
// Make a DsGridPointer instance containing all D tensors
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
DsGridPointer p_ds_grid;
std::array<index_t, NumDTensor> stride_Ds;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
stride_Ds[i] = gemm_desc_ptr[group_id].StrideDs[i];
});
index_t K_split = ck::math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
// Update the tile offset if we have moved within the group
b2c_tile_map.UpdateTileOffset(tile_offset);
using Problem = typename GridwiseGemm::Problem;
auto problem = Problem(gemm_desc_ptr[group_id].M,
gemm_desc_ptr[group_id].N,
gemm_desc_ptr[group_id].K,
std::array<index_t, 1>{gemm_desc_ptr[group_id].StrideA},
std::array<index_t, 1>{gemm_desc_ptr[group_id].StrideB},
stride_Ds,
gemm_desc_ptr[group_id].StrideE,
1);
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
constexpr TailNumber TailNum = TailNumber::Full;
if(has_main_k_block_loop)
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
GridwiseGemm::template Run<true, InMemoryDataOperationEnum::Set, TailNum>(
p_as_grid,
p_bs_grid,
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
b2c_tile_map,
a_element_op,
b_element_op,
cde_element_op,
epilogue_args);
}
}
else
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
GridwiseGemm::template Run<false, InMemoryDataOperationEnum::Set, TailNum>(
p_as_grid,
p_bs_grid,
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
b2c_tile_map,
a_element_op,
b_element_op,
cde_element_op,
epilogue_args);
}
}
tile_id += get_grid_size();
tile_offset += get_grid_size();
} while(group_id < group_count);
#else
ignore = gemm_descs_const;
ignore = group_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
#endif // end of if (defined(__gfx11__) || defined(__gfx12__))
}
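
To illustrate the work distribution performed by the do/while loop above, a hypothetical host-side sketch (not CK code): each workgroup starts at its block id, strides through tile ids by the grid size, and scans the groups to find which GEMM owns each tile; groups with zero tiles (M, N or K equal to 0) are skipped.

#include <cstdio>
#include <vector>

int main()
{
    // Number of output C-tiles per GEMM group; the empty group mimics an M/N/K == 0 entry.
    std::vector<int> tiles_per_group = {5, 0, 7, 3};
    const int grid_size = 4; // pretend occupancy-limited grid size

    auto run_block = [&](int block) {
        int group_id = -1, tile_start = 0, tile_end = 0, group_offset = 0;
        for(int tile = block;; tile += grid_size) // persistent loop, stride = grid size
        {
            // Find the GEMM group that owns this tile (mirrors the inner while loop above).
            while(!(tile >= tile_start && tile < tile_end))
            {
                group_offset = tile_end;
                if(++group_id >= static_cast<int>(tiles_per_group.size()))
                    return; // all groups exhausted -> this workgroup is done
                tile_start = group_offset;
                tile_end   = group_offset + tiles_per_group[group_id];
            }
            std::printf("block %d -> group %d, local tile %d\n", block, group_id, tile - tile_start);
        }
    };

    for(int block = 0; block < grid_size; ++block)
        run_block(block);
    return 0;
}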
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerWmma,
ck::index_t NPerWmma,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
: public DeviceGroupedGemmTileLoop<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3;
static constexpr index_t NumDTensor = DsDataType::Size();
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
DsLayout,
ELayout,
Tuple<ADataType>,
Tuple<BDataType>,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
false, // PermuteA not supported by GridwiseOp.
false>; // PermuteB not supported by DeviceGroupedGemmTileLoop base class.
using KernelConfig = TileLoopKernelConfig<BlockSize>;
using KernelArguments = GroupedGemmKernelArgument<NumDTensor>;
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;
// Argument
struct Argument : public BaseArgument
{
Argument(std::vector<const void*>& /* p_As */,
std::vector<const void*>& /* p_Bs */,
std::vector<std::array<const void*, NumDTensor>>& /* p_Ds */,
std::vector<void*>& /* p_Es */,
const std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
int occupancy_num_blocks,
int gpu_cu_count)
: group_count_{static_cast<index_t>(gemm_descs.size())},
occupancy_num_blocks_{occupancy_num_blocks},
gpu_cu_count_{gpu_cu_count},
gemm_descs_{gemm_descs},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
tile_count_{0}
{
for(const auto& desc : gemm_descs)
{
const auto M = desc.M_;
const auto N = desc.N_;
const auto b2c_tile_map = Block2ETileMap(M, N);
tile_count_ += b2c_tile_map.CalculateGridSize(M, N);
}
}
index_t group_count_;
const void* p_dev_gemm_args_;
int occupancy_num_blocks_;
int gpu_cu_count_;
const std::vector<GemmDesc>& gemm_descs_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
index_t tile_count_;
};
// Invoker
struct Invoker : public BaseInvoker
{
///
/// @brief Launch Grouped Gemm kernel.
///
/// @note This overload uses a user-provided device buffer for the kernel
/// arguments.
///
/// @param[in] arg The structure containing kernel arguments (in host
/// memory).
/// @param[in] dev_gemm_args The pointer to device memory with kernel arguments.
/// @param[in] stream_config The device stream configuration.
///
/// @return The average kernel execution time (if time measurement is enabled).
///
float Run(const Argument& arg,
const void* dev_gemm_args,
const StreamConfig& stream_config = StreamConfig{})
{
if(dev_gemm_args == nullptr)
{
std::ostringstream err;
err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
const auto kernel = GetKernelFunction();
int grid_size = KernelConfig::CalculateMaxOccupancyGridSize(kernel, stream_config);
if(stream_config.log_level_ > 0)
{
std::cout << "grid_size: " << grid_size << " tile_count: " << arg.tile_count_
<< std::endl;
}
// run multiple kernels
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(dev_gemm_args),
arg.group_count_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_);
}
///
/// @brief Launch Grouped Gemm kernel.
///
/// @note This overload uses the device buffers (for kernel arguments and for
/// auxiliary kernel workspace) provided with the argument. The user should call
/// @see GetDeviceKernelArgSize and @see SetDeviceKernelArgs on the arg parameter
/// to properly allocate those buffers.
///
/// @param[in] arg The structure containing kernel arguments (in host memory).
/// @param[in] stream_config The device stream configuration.
///
/// @return The average kernel execution time (if time measurement is enabled).
///
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(arg.p_dev_gemm_args_ == nullptr)
{
std::ostringstream err;
err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
return Run(arg, arg.p_dev_gemm_args_, stream_config);
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static auto GetKernelFunction()
{
const auto kernel = kernel_grouped_gemm_multiple_d_wmma<GridwiseGemm,
KernelArguments,
ADataType,
BDataType,
DsDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
KPerBlock,
OffsettedLocalBlock2ETileMap,
Block2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
return kernel;
}
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
{
return false;
}
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
{
if(ck::is_gfx11_supported())
{
return false;
}
}
bool supported = true;
for(index_t i = 0; i < arg.group_count_; ++i)
{
std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
std::array<index_t, NumDTensor> stride_Ds;
std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());
typename GridwiseGemm::Argument gridwise_arg(
std::array<const void*, 1>{nullptr}, // p_a_grid,
std::array<const void*, 1>{nullptr}, // p_b_grid,
placeholder_p_ds_grid, // p_ds_grid,
nullptr, // p_e_grid ,
arg.gemm_descs_[i].M_,
arg.gemm_descs_[i].N_,
arg.gemm_descs_[i].K_,
std::array<index_t, 1>{arg.gemm_descs_[i].stride_A_},
std::array<index_t, 1>{arg.gemm_descs_[i].stride_B_},
stride_Ds,
arg.gemm_descs_[i].stride_C_,
1, // KBatch
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
false);
bool group_arg_valid = GridwiseGemm::CheckValidity(gridwise_arg);
supported = supported && group_arg_valid;
if(!group_arg_valid)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " has invalid GridwiseGemm settings!" << std::endl;
gridwise_arg.Print();
}
}
}
return supported;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static int GetKernelOccupancy()
{
const auto kernel = GetKernelFunction();
return KernelConfig::GetKernelOccupancy(kernel);
}
static auto MakeArgument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_elementwise_op,
BElementwiseOperation b_elementwise_op,
CDEElementwiseOperation cde_elementwise_op)
{
int occupancy = GetKernelOccupancy();
int num_cu = KernelConfig::GetComputeUnitCount();
return Argument{p_As,
p_Bs,
p_Ds,
p_Es,
gemm_descs,
a_elementwise_op,
b_elementwise_op,
cde_elementwise_op,
occupancy,
num_cu};
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_elementwise_op,
BElementwiseOperation b_elementwise_op,
CDEElementwiseOperation cde_elementwise_op) override
{
int occupancy = GetKernelOccupancy();
int num_cu = KernelConfig::GetComputeUnitCount();
return std::make_unique<Argument>(p_As,
p_Bs,
p_Ds,
p_Es,
gemm_descs,
a_elementwise_op,
b_elementwise_op,
cde_elementwise_op,
occupancy,
num_cu);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::ostringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3"
<< "<"
<< std::string(ALayout::name)[0] << ","
<< std::string(BLayout::name)[0] << ","
<< std::string(ELayout::name)[0] << ","
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerWmma << ", "
<< NPerWmma << ", "
<< MRepeat << ", "
<< NRepeat << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMRepeatPerShuffle << ", "
<< CShuffleNRepeatPerShuffle << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
<< ">";
// clang-format on
return str.str();
}
void SetDeviceKernelArgs(Argument& arg,
void* p_dev_kernel_args,
const void* p_host_kernel_args) const
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
hip_check_error(hipMemcpyAsync(p_dev_kernel_args,
p_host_kernel_args,
GetDeviceKernelArgSize(&arg),
hipMemcpyHostToDevice));
}
virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
void* p_dev_kernel_args,
const void* p_host_kernel_args) const override
{
return SetDeviceKernelArgs(
*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args, p_host_kernel_args);
}
void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
}
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(KernelArguments);
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck


@@ -4,6 +4,7 @@
#pragma once
#include <iostream>
#include <optional>
#include <sstream>
#include <tuple>
@@ -26,6 +27,18 @@ namespace ck {
namespace tensor_operation {
namespace device {
// Dummy kernel used as a fallback in the kernel selection logic.
// It is never launched in practice; it is only selected in case of misconfigured parameters.
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
__global__ void kernel_dummy(const void CK_CONSTANT_ADDRESS_SPACE*,
const index_t,
const AElementwiseOperation,
const BElementwiseOperation,
const CDEElementwiseOperation)
{
}
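
A minimal standalone illustration (not CK code) of the fallback pattern this dummy kernel enables: when the compile-time configuration is invalid, the kernel selection function still returns a harmless callable, so occupancy queries issued before IsSupportedArgument() have something to inspect.

#include <cstdio>

void dummy_kernel(int) {} // never launched in practice
void real_kernel(int x) { std::printf("real kernel, x = %d\n", x); }

// Hypothetical stand-in for GetKernelFunction(): hand out the real kernel only when the
// compile-time configuration is valid, otherwise fall back to the dummy.
template <int NPerWave>
auto get_kernel_function()
{
    if constexpr(NPerWave > 0)
        return &real_kernel;
    else
        return &dummy_kernel;
}

int main()
{
    get_kernel_function<4>()(42); // valid configuration -> real kernel
    get_kernel_function<0>()(0);  // misconfigured -> dummy fallback, a no-op
    return 0;
}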
///
/// @brief Entry point kernel for device-wide Grouped GEMM operation.
///
@@ -528,6 +541,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
using KernelConfig = TileLoopKernelConfig<BlockSize>;
using KernelArguments = GroupedGemmKernelArgument<NumDTensor>;
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;
@@ -574,22 +588,6 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
index_t tile_count_;
};
struct KernelConfig
{
// The oversubscription factor for the number of blocks that can simultaneously reside on
// GPU.
static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
// static constexpr int BLOCK_WAVES = BlockSize / get_warp_size();
static constexpr int CU_SIMDS = 4;
// Assume we want to have at most 2 waves per SIMD
// static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
static int GetCuBlocks()
{
int BLOCK_WAVES = BlockSize / get_warp_size();
return math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
}
};
// Invoker
struct Invoker : public BaseInvoker
{
@@ -666,58 +664,17 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
const void* dev_gemm_args,
const StreamConfig& stream_config) const
{
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
KernelArguments,
GemmSpec,
ADataType,
BDataType,
DsDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
KPerBlock,
OffsettedLocalBlock2ETileMap,
Block2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
const auto kernel = GetKernelFunction<GridwiseGemm>();
return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
}
template <typename KernelFunction>
int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
const StreamConfig& stream_config) const
{
// Calculate max number of workgroups that can simultaneously reside on the CU.
int occ_num_blocks = 0;
size_t dyn_shared_mem_per_blk = 0;
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occ_num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));
int cu_count = getAvailableComputeUnitCount(stream_config);
if(stream_config.log_level_ > 0)
{
std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks
<< ", available CUs count: " << cu_count << ", occup. grid size: "
<< ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks()) * cu_count
<< std::endl;
}
return cu_count * ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks());
}
template <typename KernelFunction>
float LaunchKernel(const KernelFunction& kernel,
const Argument& arg,
const void* dev_gemm_args,
const StreamConfig& stream_config) const
{
int grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config);
int grid_size = KernelConfig::CalculateMaxOccupancyGridSize(kernel, stream_config);
if(stream_config.log_level_ > 0)
{
@@ -835,65 +792,60 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static int GetKernelOccupancy()
template <typename GridwiseGemm>
static auto GetKernelFunction()
{
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
KernelArguments,
GemmSpec,
ADataType,
BDataType,
DsDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
KPerBlock,
OffsettedLocalBlock2ETileMap,
Block2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
return kernel;
}
static auto GetKernelFunction()
{
int occupancy = 0;
if(get_warp_size() == 64)
{
if constexpr(NXdlPerWave64 > 0)
{
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm64,
KernelArguments,
GemmSpec,
ADataType,
BDataType,
DsDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
KPerBlock,
OffsettedLocalBlock2ETileMap,
Block2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
const auto kernel = GetKernelFunction<GridwiseGemm64>();
return kernel;
}
}
else
{
if constexpr(NXdlPerWave32 > 0)
{
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm32,
KernelArguments,
GemmSpec,
ADataType,
BDataType,
DsDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
KPerBlock,
OffsettedLocalBlock2ETileMap,
Block2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
const auto kernel = GetKernelFunction<GridwiseGemm32>();
return kernel;
}
}
return occupancy;
// This handles the case where MXdlPerWave/NXdlPerWave is too small.
// That case is caught by IsSupportedArgument(), but since GetKernelFunction() is sometimes
// called before that check, we need a fallback kernel to return here.
return kernel_dummy<AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation>;
}
static int GetKernelOccupancy()
{
const auto kernel = GetKernelFunction();
return KernelConfig::GetKernelOccupancy(kernel);
}
static auto MakeArgument(std::vector<const void*>& p_As,
@@ -906,13 +858,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
CDEElementwiseOperation cde_elementwise_op)
{
int occupancy = GetKernelOccupancy();
int num_cu;
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
int num_cu = KernelConfig::GetComputeUnitCount();
return Argument{p_As,
p_Bs,
@@ -937,13 +883,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
CDEElementwiseOperation cde_elementwise_op) override
{
int occupancy = GetKernelOccupancy();
int num_cu;
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
int num_cu = KernelConfig::GetComputeUnitCount();
return std::make_unique<Argument>(p_As,
p_Bs,


@@ -126,7 +126,6 @@ template <typename ALayout,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
@@ -158,9 +157,7 @@ template <typename ALayout,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
typename ComputeTypeB = ComputeTypeA>
struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayout,
BLayout,
DsLayout,
@@ -231,8 +228,8 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
false, // PermuteA not supported by DeviceBatchedGemm base class.
false>; // PermuteB not supported by DeviceBatchedGemm base class.
false, // PermuteA not supported by GridwiseOp
false>; // PermuteB not supported by DeviceGroupedGemm base class
using CGridDesc_M_N =
remove_cvref_t<decltype(GridwiseGemm::template MakeDEGridDescriptor_M_N<ELayout>(
@@ -779,7 +776,7 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGroupedGemm_WmmaSplitK"
str << "DeviceGroupedGemm_Wmma_CShuffleV3"
<< "<"
<< std::string(ALayout::name)[0] << ","
<< std::string(BLayout::name)[0] << ","


@@ -8,6 +8,7 @@
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/quantization_operation.hpp"
#include "ck/utility/type_convert.hpp"
namespace ck {
namespace tensor_operation {
@@ -236,8 +237,9 @@ struct MultiplyAdd
const half_t& d0,
const half_t& d1) const
{
const half_t y = type_convert<half_t>(c) * d0 + d1;
e = y;
const half_t y =
type_convert<half_t>(c * type_convert<float>(d0) + type_convert<float>(d1));
e = y;
}
template <>
__host__ __device__ void operator()<bhalf_t, float, bhalf_t, bhalf_t>(bhalf_t& e,
@@ -245,8 +247,9 @@ struct MultiplyAdd
const bhalf_t& d0,
const bhalf_t& d1) const
{
const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
e = y;
const bhalf_t y =
type_convert<bhalf_t>(c * type_convert<float>(d0) + type_convert<float>(d1));
e = y;
}
template <>
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,

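The MultiplyAdd change above keeps the multiply-add in float and converts to the narrow type only once at the end, instead of narrowing the accumulator first. A rough standalone analogy using double/float in place of float/half_t (values chosen to make the early-rounding error visible; they are illustrative only):

#include <cstdio>

int main()
{
    double c  = 1.0000001;          // wide accumulator (float in the CK operator)
    float  d0 = 1.0f, d1 = -1.0f;   // narrow auxiliary operands (half_t in the CK operator)

    // Old formulation: narrow the accumulator first, then multiply-add in the narrow type.
    // The low-order bits of c are lost before the cancellation against d1.
    float early_rounding = static_cast<float>(c) * d0 + d1;

    // New formulation: multiply-add in the wide type, round once when storing the result.
    float late_rounding =
        static_cast<float>(c * static_cast<double>(d0) + static_cast<double>(d1));

    std::printf("early rounding: %.9g, late rounding: %.9g\n", early_rounding, late_rounding);
    return 0;
}
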

@@ -334,14 +334,14 @@ struct GridwiseGemm_wmma_cshuffle_v3
struct Problem
{
__host__ Problem() = default;
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
std::array<index_t, NumATensor> StrideAs_,
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t KBatch_)
__host__ __device__ Problem(index_t M_,
index_t N_,
index_t K_,
std::array<index_t, NumATensor> StrideAs_,
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t KBatch_)
: M{M_},
N{N_},
K{K_},


@@ -351,64 +351,65 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
// Calculate grid size taking into account splitk (KBatch)
// 2D grid (x,z)
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
__host__ __device__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
}
// Calculate grid size taking into account splitk (KBatch) and multiple groups (Batch)
// 3D grid (x,y,z)
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
__host__ __device__ static auto
CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch);
}
__host__ static auto CalculateMPadded(index_t M)
__host__ __device__ static auto CalculateMPadded(index_t M)
{
return math::integer_least_multiple(M, MPerBlock);
}
__host__ static auto CalculateNPadded(index_t N)
__host__ __device__ static auto CalculateNPadded(index_t N)
{
return math::integer_least_multiple(N, NPerBlock);
}
__host__ static auto CalculateKPadded(index_t K)
__host__ __device__ static auto CalculateKPadded(index_t K)
{
return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
}
__host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
__host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
}
__host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
__host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
}
__host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * KPerBlock;
}
__host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
__host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
{
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
auto K_t = K_Batch * KReadVec;
return (K + K_t - 1) / K_t * KReadVec;
}
__host__ static auto CalculateMBlock(index_t M)
__host__ __device__ static auto CalculateMBlock(index_t M)
{
return math::integer_divide_ceil(M, MPerBlock);
}
__host__ static auto CalculateNBlock(index_t N)
__host__ __device__ static auto CalculateNBlock(index_t N)
{
return math::integer_divide_ceil(N, NPerBlock);
}
@@ -963,14 +964,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
return true;
}
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
}
__host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
__host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
{
const index_t num_loop = K / KPerBlock;

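As a numeric illustration of the padding helpers above, a hypothetical worked example with K = 1000, KPerBlock = 64, AK1 = BK1 = 8 and K_Batch = 2 (the formulas mirror CalculateKPadded, CalculateAK0Padded and CalculateBK0Padded):

#include <cstdio>

int main()
{
    const int K = 1000, KPerBlock = 64, AK1 = 8, BK1 = 8, K_Batch = 2;

    // CalculateKPadded(K, K_Batch): each K split is padded up to a multiple of KPerBlock.
    const int K_t     = K_Batch * KPerBlock;                     // 128
    const int KPadded = (K + K_t - 1) / K_t * KPerBlock;         // ceil(1000/128) * 64 = 512
    // CalculateAK0Padded / CalculateBK0Padded: number of AK1/BK1 sized chunks per split.
    const int AK0     = (K + K_t - 1) / K_t * (KPerBlock / AK1); // 8 * 8 = 64
    const int BK0     = (K + K_t - 1) / K_t * (KPerBlock / BK1); // 8 * 8 = 64

    std::printf("KPadded per split = %d, AK0 = %d, BK0 = %d\n", KPadded, AK0, BK0);
    return 0;
}
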

@@ -7,6 +7,7 @@
#include "ck/utility/sequence.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/enable_if.hpp"
#include <tuple>
namespace ck {
@@ -220,4 +221,49 @@ constexpr Tuple<Args&...> tie(Args&... args) noexcept
return {args...};
}
//
// tuple_map: map each element type of a Tuple through a wrapper template,
// e.g. tuple_map_t<Wrapper, Tuple<T1, T2, T3>> is Tuple<Wrapper<T1>, Wrapper<T2>, Wrapper<T3>>
//
template <template <typename> class Wrapper, typename Tuple>
struct tuple_map;
template <template <typename> class Wrapper, typename... Ts>
struct tuple_map<Wrapper, Tuple<Ts...>>
{
using type = Tuple<Wrapper<Ts>...>;
};
template <template <typename> class Wrapper, typename Tuple>
using tuple_map_t = typename tuple_map<Wrapper, Tuple>::type;
//
// tuple_element_or: helper to access a tuple's element type by index, falling back to a
// default type when the index is out of range of the tuple size
//
namespace detail {
// Base template (will be specialized on the boolean)
template <ck::index_t N, typename Tuple, typename Default, bool InRange = (N < Tuple::Size())>
struct tuple_element_or_impl;
// Specialization for the in-range case: use tuple_element_t
template <ck::index_t N, typename Tuple, typename Default>
struct tuple_element_or_impl<N, Tuple, Default, true>
{
using type = tuple_element_t<N, Tuple>;
};
// Specialization for the out-of-range case: use Default
template <ck::index_t N, typename Tuple, typename Default>
struct tuple_element_or_impl<N, Tuple, Default, false>
{
using type = Default;
};
} // namespace detail
// User-facing alias
template <ck::index_t N, typename Tuple, typename Default>
using tuple_element_or_t = typename detail::tuple_element_or_impl<N, Tuple, Default>::type;
} // namespace ck
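
A short compile-time usage sketch of the two utilities above; the Wrap helper is hypothetical and the include path assumes this header lives at ck/utility/tuple.hpp:

#include <type_traits>
#include "ck/utility/tuple.hpp" // assumed location of this header

template <typename T>
struct Wrap
{
    T value;
};

// tuple_map_t wraps every element type of the Tuple.
static_assert(std::is_same_v<ck::tuple_map_t<Wrap, ck::Tuple<int, float>>,
                             ck::Tuple<Wrap<int>, Wrap<float>>>);

// tuple_element_or_t yields the element type when the index is in range, otherwise the default.
static_assert(std::is_same_v<ck::tuple_element_or_t<0, ck::Tuple<int, float>, double>, int>);
static_assert(std::is_same_v<ck::tuple_element_or_t<5, ck::Tuple<int, float>, double>, double>);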