mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Implement grouped gemm tile loop for RDNA4 (#3304)
* feat: grouped gemm tile loop support for RDNA4
* fix: removed extra parameter from grouped gemm example instance
* fix: FP8 check incorrectly enabling FP8 on RDNA3
This commit is contained in:
@@ -151,7 +151,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
|
||||
static bool __host__ __device__ BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
static TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
@@ -707,7 +710,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
|
||||
/// @brief Reports whether the main ("hot") loop has to be executed.
/// @param num_loop Loop count supplied by the caller.
/// @return true when PrefetchStages is smaller than num_loop.
__host__ __device__ static bool BlockHasHotloop(index_t num_loop)
{
    // Same predicate as `num_loop > PrefetchStages`, written with the operands swapped.
    return PrefetchStages < num_loop;
}
|
||||
|
||||
static TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
|
||||
@@ -3,6 +3,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/stream_utility.hpp"
|
||||
|
||||
#include "device_grouped_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -43,6 +48,59 @@ struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm<ALayout,
|
||||
{
|
||||
};
|
||||
|
||||
template <ck::index_t BlockSize>
|
||||
struct TileLoopKernelConfig
|
||||
{
|
||||
// The oversubscription factor for the number of blocks that can simultaneously reside on
|
||||
// GPU.
|
||||
static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
|
||||
// static constexpr int BLOCK_WAVES = BlockSize / get_warp_size();
|
||||
static constexpr int CU_SIMDS = 4;
|
||||
// Assume we want to have at most 2 waves per SIMD
|
||||
// static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
|
||||
static int GetCuBlocks()
|
||||
{
|
||||
int BLOCK_WAVES = BlockSize / get_warp_size();
|
||||
return ck::math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
|
||||
}
|
||||
|
||||
template <typename KernelFunction>
|
||||
static int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
|
||||
const StreamConfig& stream_config)
|
||||
{
|
||||
// Calculate max number of workgroups that can simultaneously reside on the CU.
|
||||
int occ_num_blocks = GetKernelOccupancy(kernel);
|
||||
int cu_count = getAvailableComputeUnitCount(stream_config);
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks
|
||||
<< ", available CUs count: " << cu_count << ", occup. grid size: "
|
||||
<< ck::math::min(occ_num_blocks, GetCuBlocks()) * cu_count << std::endl;
|
||||
}
|
||||
|
||||
return cu_count * ck::math::min(occ_num_blocks, GetCuBlocks());
|
||||
}
|
||||
|
||||
template <typename KernelFunction>
|
||||
static int GetKernelOccupancy(const KernelFunction& kernel)
|
||||
{
|
||||
int occupancy = 0;
|
||||
ck::hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
|
||||
return occupancy;
|
||||
}
|
||||
|
||||
static int GetComputeUnitCount()
|
||||
{
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
ck::hip_check_error(hipGetDevice(&dev));
|
||||
ck::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -0,0 +1,689 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/host_utility/stream_utility.hpp"
|
||||
#include "ck/utility/loop_scheduler.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
///
|
||||
/// @brief Entry point kernel for device-wide Grouped GEMM operation.
|
||||
///
|
||||
/// @param[in] gemm_descs_const The pointer to the array of GEMM descriptor structures.
|
||||
/// @param[in] group_count The number of together processed GEMMs.
|
||||
///
|
||||
/// @tparam GridwiseGemm The specific GridwiseGEMM algorithm implementation.
|
||||
/// @tparam GemmDesc The structure holding all necessary descriptors and
|
||||
/// other data needed for grouped gemm calculation and work
|
||||
/// distribution.
|
||||
/// @tparam LocalBlock2ETileMap The structure providing mapping between workgroup ids,
|
||||
/// the data tiles to process and the output tiles.
|
||||
///
|
||||
template <typename GridwiseGemm,
|
||||
typename GemmDesc,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
index_t KPerBlock,
|
||||
typename OffsettedBlockToCTileMap,
|
||||
typename LocalBlock2ETileMap,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_grouped_gemm_multiple_d_wmma(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
const index_t group_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
__shared__ uint8_t p_shared[LDS_size];
|
||||
|
||||
const auto gemm_desc_ptr =
|
||||
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
|
||||
constexpr auto NumDTensor = DsDataType::Size();
|
||||
index_t tile_id = get_block_1d_id();
|
||||
index_t tile_offset = 0;
|
||||
index_t group_id = -1;
|
||||
index_t group_offset = 0;
|
||||
index_t grid_size_grp = 0;
|
||||
|
||||
index_t gemm_tile_id_start = 0;
|
||||
index_t gemm_tile_id_end = 0;
|
||||
|
||||
index_t M = 0, N = 0, K = 0;
|
||||
|
||||
auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1);
|
||||
|
||||
do
|
||||
{
|
||||
// Find corresponding GEMM group for our tile
|
||||
while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) &&
|
||||
group_id < group_count)
|
||||
{
|
||||
group_offset += grid_size_grp;
|
||||
group_id++;
|
||||
|
||||
if(group_id >= group_count)
|
||||
return;
|
||||
|
||||
M = gemm_desc_ptr[group_id].M;
|
||||
N = gemm_desc_ptr[group_id].N;
|
||||
K = gemm_desc_ptr[group_id].K;
|
||||
|
||||
if(M == 0 || N == 0 || K == 0)
|
||||
{
|
||||
grid_size_grp = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
b2c_tile_map =
|
||||
OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset);
|
||||
grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);
|
||||
|
||||
gemm_tile_id_start = group_offset;
|
||||
gemm_tile_id_end = group_offset + grid_size_grp;
|
||||
}
|
||||
|
||||
// Create A&B grid pointer containing their single tensors
|
||||
typename GridwiseGemm::AsGridPointer p_as_grid = Tuple<const ADataType*>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid));
|
||||
typename GridwiseGemm::BsGridPointer p_bs_grid = Tuple<const BDataType*>(
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid));
|
||||
|
||||
// Make a DsGridPointer instance containing all D tensors
|
||||
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
|
||||
DsGridPointer p_ds_grid;
|
||||
std::array<index_t, NumDTensor> stride_Ds;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
|
||||
stride_Ds[i] = gemm_desc_ptr[group_id].StrideDs[i];
|
||||
});
|
||||
|
||||
index_t K_split = ck::math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
// Update tile offset if we have moved within group
|
||||
b2c_tile_map.UpdateTileOffset(tile_offset);
|
||||
|
||||
using Problem = typename GridwiseGemm::Problem;
|
||||
auto problem = Problem(gemm_desc_ptr[group_id].M,
|
||||
gemm_desc_ptr[group_id].N,
|
||||
gemm_desc_ptr[group_id].K,
|
||||
std::array<index_t, 1>{gemm_desc_ptr[group_id].StrideA},
|
||||
std::array<index_t, 1>{gemm_desc_ptr[group_id].StrideB},
|
||||
stride_Ds,
|
||||
gemm_desc_ptr[group_id].StrideE,
|
||||
1);
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
constexpr TailNumber TailNum = TailNumber::Full;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
GridwiseGemm::template Run<true, InMemoryDataOperationEnum::Set, TailNum>(
|
||||
p_as_grid,
|
||||
p_bs_grid,
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
b2c_tile_map,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
epilogue_args);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
GridwiseGemm::template Run<false, InMemoryDataOperationEnum::Set, TailNum>(
|
||||
p_as_grid,
|
||||
p_bs_grid,
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
b2c_tile_map,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
epilogue_args);
|
||||
}
|
||||
}
|
||||
|
||||
tile_id += get_grid_size();
|
||||
tile_offset += get_grid_size();
|
||||
|
||||
} while(group_id < group_count);
|
||||
#else
|
||||
ignore = gemm_descs_const;
|
||||
ignore = group_count;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
#endif // end of if (defined(__gfx11__) || defined(__gfx12__))
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerWmma,
|
||||
ck::index_t NPerWmma,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = EDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
|
||||
struct DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
|
||||
: public DeviceGroupedGemmTileLoop<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
false, // PermuteA not supported by GridwiseOp.
|
||||
false>; // PermuteB not supported by DeviceGroupedGemmTileLoop base class.
|
||||
|
||||
using KernelConfig = TileLoopKernelConfig<BlockSize>;
|
||||
using KernelArguments = GroupedGemmKernelArgument<NumDTensor>;
|
||||
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(std::vector<const void*>& /* p_As */,
|
||||
std::vector<const void*>& /* p_Bs */,
|
||||
std::vector<std::array<const void*, NumDTensor>>& /* p_Ds */,
|
||||
std::vector<void*>& /* p_Es */,
|
||||
const std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
int occupancy_num_blocks,
|
||||
int gpu_cu_count)
|
||||
: group_count_{static_cast<index_t>(gemm_descs.size())},
|
||||
occupancy_num_blocks_{occupancy_num_blocks},
|
||||
gpu_cu_count_{gpu_cu_count},
|
||||
gemm_descs_{gemm_descs},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
tile_count_{0}
|
||||
{
|
||||
for(const auto& desc : gemm_descs)
|
||||
{
|
||||
const auto M = desc.M_;
|
||||
const auto N = desc.N_;
|
||||
const auto b2c_tile_map = Block2ETileMap(M, N);
|
||||
tile_count_ += b2c_tile_map.CalculateGridSize(M, N);
|
||||
}
|
||||
}
|
||||
|
||||
index_t group_count_;
|
||||
const void* p_dev_gemm_args_;
|
||||
int occupancy_num_blocks_;
|
||||
int gpu_cu_count_;
|
||||
const std::vector<GemmDesc>& gemm_descs_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
index_t tile_count_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
///
|
||||
/// @brief Launch Grouped Gemm kernel.
|
||||
///
|
||||
/// @note This function overload is using user provided device buffer for kernel
|
||||
/// arguments.
|
||||
///
|
||||
/// @param[in] arg The structure containing kernel arguments (in host
|
||||
/// memory).
|
||||
/// @param[in] dev_gemm_args The pointer to device memory with kernel arguments.
|
||||
/// @param[in] stream_config The device stream configuration.
|
||||
///
|
||||
/// @return The average kernel execution time (if time measurement is enabled.)
|
||||
///
|
||||
float Run(const Argument& arg,
|
||||
const void* dev_gemm_args,
|
||||
const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(dev_gemm_args == nullptr)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
const auto kernel = GetKernelFunction();
|
||||
|
||||
int grid_size = KernelConfig::CalculateMaxOccupancyGridSize(kernel, stream_config);
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
std::cout << "grid_size: " << grid_size << " tile_count: " << arg.tile_count_
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// run multiple kernels
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(dev_gemm_args),
|
||||
arg.group_count_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_);
|
||||
}
|
||||
|
||||
///
|
||||
/// @brief Launch Grouped Gemm kernel.
|
||||
///
|
||||
/// @note This function overload is using device buffers (for kernel arguments and
|
||||
/// for kernel auxiliary workspace) provided with an argument. The user should
|
||||
/// call @see GetDeviceKernelArgSize, and @see SetDeviceKernelArgs, on arg
|
||||
/// parameter to properly allocate those buffers.
|
||||
///
|
||||
/// @param[in] arg The structure containing kernel arguments (in host memory).
|
||||
/// @param[in] stream_config The device stream configuration.
|
||||
///
|
||||
/// @return The average kernel execution time (if time measurement is enabled.)
|
||||
///
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(arg.p_dev_gemm_args_ == nullptr)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
return Run(arg, arg.p_dev_gemm_args_, stream_config);
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static auto GetKernelFunction()
|
||||
{
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_wmma<GridwiseGemm,
|
||||
KernelArguments,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
KPerBlock,
|
||||
OffsettedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
return kernel;
|
||||
}
|
||||
|
||||
/// @brief Placeholder compile-time parameter check.
/// @return Always true.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool is_valid = true;
    return is_valid;
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
|
||||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool supported = true;
|
||||
for(index_t i = 0; i < arg.group_count_; ++i)
|
||||
{
|
||||
std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
|
||||
std::array<index_t, NumDTensor> stride_Ds;
|
||||
std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());
|
||||
|
||||
typename GridwiseGemm::Argument gridwise_arg(
|
||||
std::array<const void*, 1>{nullptr}, // p_a_grid,
|
||||
std::array<const void*, 1>{nullptr}, // p_b_grid,
|
||||
placeholder_p_ds_grid, // p_ds_grid,
|
||||
nullptr, // p_e_grid ,
|
||||
arg.gemm_descs_[i].M_,
|
||||
arg.gemm_descs_[i].N_,
|
||||
arg.gemm_descs_[i].K_,
|
||||
std::array<index_t, 1>{arg.gemm_descs_[i].stride_A_},
|
||||
std::array<index_t, 1>{arg.gemm_descs_[i].stride_B_},
|
||||
stride_Ds,
|
||||
arg.gemm_descs_[i].stride_C_,
|
||||
1, // KBatch
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
false);
|
||||
|
||||
bool group_arg_valid = GridwiseGemm::CheckValidity(gridwise_arg);
|
||||
supported = supported && group_arg_valid;
|
||||
|
||||
if(!group_arg_valid)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " has invalid GridwiseGemm settings!" << std::endl;
|
||||
gridwise_arg.Print();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static int GetKernelOccupancy()
|
||||
{
|
||||
const auto kernel = GetKernelFunction();
|
||||
return KernelConfig::GetKernelOccupancy(kernel);
|
||||
}
|
||||
|
||||
static auto MakeArgument(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation a_elementwise_op,
|
||||
BElementwiseOperation b_elementwise_op,
|
||||
CDEElementwiseOperation cde_elementwise_op)
|
||||
{
|
||||
int occupancy = GetKernelOccupancy();
|
||||
int num_cu = KernelConfig::GetComputeUnitCount();
|
||||
|
||||
return Argument{p_As,
|
||||
p_Bs,
|
||||
p_Ds,
|
||||
p_Es,
|
||||
gemm_descs,
|
||||
a_elementwise_op,
|
||||
b_elementwise_op,
|
||||
cde_elementwise_op,
|
||||
occupancy,
|
||||
num_cu};
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation a_elementwise_op,
|
||||
BElementwiseOperation b_elementwise_op,
|
||||
CDEElementwiseOperation cde_elementwise_op) override
|
||||
{
|
||||
int occupancy = GetKernelOccupancy();
|
||||
int num_cu = KernelConfig::GetComputeUnitCount();
|
||||
|
||||
return std::make_unique<Argument>(p_As,
|
||||
p_Bs,
|
||||
p_Ds,
|
||||
p_Es,
|
||||
gemm_descs,
|
||||
a_elementwise_op,
|
||||
b_elementwise_op,
|
||||
cde_elementwise_op,
|
||||
occupancy,
|
||||
num_cu);
|
||||
}
|
||||
|
||||
/// @brief Returns a value-initialized Invoker.
static auto MakeInvoker()
{
    return Invoker{};
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::ostringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3"
|
||||
<< "<"
|
||||
<< std::string(ALayout::name)[0] << ","
|
||||
<< std::string(BLayout::name)[0] << ","
|
||||
<< std::string(ELayout::name)[0] << ","
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< MPerWmma << ", "
|
||||
<< NPerWmma << ", "
|
||||
<< MRepeat << ", "
|
||||
<< NRepeat << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< CShuffleMRepeatPerShuffle << ", "
|
||||
<< CShuffleNRepeatPerShuffle << ", "
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
///
/// @brief Register the device buffer with the argument and copy the host kernel arguments
///        into it.
///
/// @param[in,out] arg                 The argument whose p_dev_gemm_args_ is set.
/// @param[in]     p_dev_kernel_args   Destination device buffer.
/// @param[in]     p_host_kernel_args  Source host buffer; GetDeviceKernelArgSize(&arg) bytes
///                                    are copied from it.
///
/// @note The copy is issued with hipMemcpyAsync and no stream argument is passed.
///
void SetDeviceKernelArgs(Argument& arg,
                         void* p_dev_kernel_args,
                         const void* p_host_kernel_args) const
{
    arg.p_dev_gemm_args_ = p_dev_kernel_args;

    const size_t arg_bytes = GetDeviceKernelArgSize(&arg);
    hip_check_error(hipMemcpyAsync(
        p_dev_kernel_args, p_host_kernel_args, arg_bytes, hipMemcpyHostToDevice));
}
|
||||
|
||||
/// @brief Type-erased overload; downcasts p_arg to Argument (the result is dereferenced without
///        a null check) and forwards to the three-parameter Argument overload.
virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
                                 void* p_dev_kernel_args,
                                 const void* p_host_kernel_args) const override
{
    Argument* typed_arg = dynamic_cast<Argument*>(p_arg);
    SetDeviceKernelArgs(*typed_arg, p_dev_kernel_args, p_host_kernel_args);
}
|
||||
|
||||
/// @brief Register an already filled device buffer with the argument; nothing is copied.
/// @param[in,out] arg                The argument whose p_dev_gemm_args_ is set.
/// @param[in]     p_dev_kernel_args  The device buffer holding the kernel arguments.
void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
{
    const void* dev_args = p_dev_kernel_args;
    arg.p_dev_gemm_args_ = dev_args;
}
|
||||
|
||||
/// @brief Type-erased overload; downcasts p_arg to Argument (the result is dereferenced without
///        a null check) and forwards to the two-parameter Argument overload.
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
    Argument* typed_arg = dynamic_cast<Argument*>(p_arg);
    SetDeviceKernelArgs(*typed_arg, p_dev_kernel_args);
}
|
||||
|
||||
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(KernelArguments);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
|
||||
@@ -26,6 +27,18 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Dummy kernel used as a fallback in the kernel selection logic.
|
||||
// It is not expected to be used in practice; it is only selected when the parameters are
// misconfigured. It has the same parameter list as the grouped GEMM kernels and ignores
// every argument.
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
__global__ void kernel_dummy(const void CK_CONSTANT_ADDRESS_SPACE*,
|
||||
const index_t,
|
||||
const AElementwiseOperation,
|
||||
const BElementwiseOperation,
|
||||
const CDEElementwiseOperation)
|
||||
{
// Intentionally empty.
|
||||
}
|
||||
///
|
||||
/// @brief Entry point kernel for device-wide Grouped GEMM operation.
|
||||
///
|
||||
@@ -528,6 +541,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
|
||||
using KernelConfig = TileLoopKernelConfig<BlockSize>;
|
||||
using KernelArguments = GroupedGemmKernelArgument<NumDTensor>;
|
||||
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;
|
||||
@@ -574,22 +588,6 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
index_t tile_count_;
|
||||
};
|
||||
|
||||
struct KernelConfig
|
||||
{
|
||||
// The oversubscription factor for the number of blocks that can simultaneously reside on
|
||||
// GPU.
|
||||
static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
|
||||
// static constexpr int BLOCK_WAVES = BlockSize / get_warp_size();
|
||||
static constexpr int CU_SIMDS = 4;
|
||||
// Assume we want to have at most 2 waves per SIMD
|
||||
// static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
|
||||
static int GetCuBlocks()
|
||||
{
|
||||
int BLOCK_WAVES = BlockSize / get_warp_size();
|
||||
return math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
|
||||
}
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
@@ -666,58 +664,17 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
const void* dev_gemm_args,
|
||||
const StreamConfig& stream_config) const
|
||||
{
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
KPerBlock,
|
||||
OffsettedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
const auto kernel = GetKernelFunction<GridwiseGemm>();
|
||||
return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
|
||||
}
|
||||
|
||||
template <typename KernelFunction>
|
||||
int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
|
||||
const StreamConfig& stream_config) const
|
||||
{
|
||||
// Calculate max number of workgroups that can simultaneously reside on the CU.
|
||||
int occ_num_blocks = 0;
|
||||
size_t dyn_shared_mem_per_blk = 0;
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&occ_num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));
|
||||
|
||||
int cu_count = getAvailableComputeUnitCount(stream_config);
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks
|
||||
<< ", available CUs count: " << cu_count << ", occup. grid size: "
|
||||
<< ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks()) * cu_count
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return cu_count * ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks());
|
||||
}
|
||||
|
||||
template <typename KernelFunction>
|
||||
float LaunchKernel(const KernelFunction& kernel,
|
||||
const Argument& arg,
|
||||
const void* dev_gemm_args,
|
||||
const StreamConfig& stream_config) const
|
||||
{
|
||||
int grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config);
|
||||
int grid_size = KernelConfig::CalculateMaxOccupancyGridSize(kernel, stream_config);
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
@@ -835,65 +792,60 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static int GetKernelOccupancy()
|
||||
template <typename GridwiseGemm>
|
||||
static auto GetKernelFunction()
|
||||
{
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
KPerBlock,
|
||||
OffsettedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
return kernel;
|
||||
}
|
||||
|
||||
static auto GetKernelFunction()
|
||||
{
|
||||
int occupancy = 0;
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
{
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm64,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
KPerBlock,
|
||||
OffsettedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
|
||||
const auto kernel = GetKernelFunction<GridwiseGemm64>();
|
||||
return kernel;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
{
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm32,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
KPerBlock,
|
||||
OffsettedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
|
||||
const auto kernel = GetKernelFunction<GridwiseGemm32>();
|
||||
return kernel;
|
||||
}
|
||||
}
|
||||
return occupancy;
|
||||
|
||||
// This is here to handle the case where MXdlPerWave/NXdlPerWave is too small.
|
||||
// That case is caught by IsSupportedArgument(), but GetKernelFunction() is sometimes
|
||||
// called before that check, so we need a fallback kernel to return here.
|
||||
return kernel_dummy<AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation>;
|
||||
}
|
||||
|
||||
static int GetKernelOccupancy()
|
||||
{
|
||||
const auto kernel = GetKernelFunction();
|
||||
return KernelConfig::GetKernelOccupancy(kernel);
|
||||
}
|
||||
|
||||
static auto MakeArgument(std::vector<const void*>& p_As,
|
||||
@@ -906,13 +858,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
CDEElementwiseOperation cde_elementwise_op)
|
||||
{
|
||||
int occupancy = GetKernelOccupancy();
|
||||
int num_cu;
|
||||
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
hip_check_error(hipGetDevice(&dev));
|
||||
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
num_cu = dev_prop.multiProcessorCount;
|
||||
int num_cu = KernelConfig::GetComputeUnitCount();
|
||||
|
||||
return Argument{p_As,
|
||||
p_Bs,
|
||||
@@ -937,13 +883,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
CDEElementwiseOperation cde_elementwise_op) override
|
||||
{
|
||||
int occupancy = GetKernelOccupancy();
|
||||
int num_cu;
|
||||
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
hip_check_error(hipGetDevice(&dev));
|
||||
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
num_cu = dev_prop.multiProcessorCount;
|
||||
int num_cu = KernelConfig::GetComputeUnitCount();
|
||||
|
||||
return std::make_unique<Argument>(p_As,
|
||||
p_Bs,
|
||||
|
||||
@@ -126,7 +126,6 @@ template <typename ALayout,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
@@ -158,9 +157,7 @@ template <typename ALayout,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = EDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
@@ -231,8 +228,8 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
false, // PermuteA not supported by DeviceBatchedGemm base class.
|
||||
false>; // PermuteB not supported by DeviceBatchedGemm base class.
|
||||
false, // PermuteA not supported by GridwiseOp
|
||||
false>; // PermuteB not supported by DeviceGroupedGemm base class
|
||||
|
||||
using CGridDesc_M_N =
|
||||
remove_cvref_t<decltype(GridwiseGemm::template MakeDEGridDescriptor_M_N<ELayout>(
|
||||
@@ -779,7 +776,7 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedGemm_WmmaSplitK"
|
||||
str << "DeviceGroupedGemm_Wmma_CShuffleV3"
|
||||
<< "<"
|
||||
<< std::string(ALayout::name)[0] << ","
|
||||
<< std::string(BLayout::name)[0] << ","
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/quantization_operation.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -236,8 +237,9 @@ struct MultiplyAdd
|
||||
const half_t& d0,
|
||||
const half_t& d1) const
|
||||
{
|
||||
const half_t y = type_convert<half_t>(c) * d0 + d1;
|
||||
e = y;
|
||||
const half_t y =
|
||||
type_convert<half_t>(c * type_convert<float>(d0) + type_convert<float>(d1));
|
||||
e = y;
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float, bhalf_t, bhalf_t>(bhalf_t& e,
|
||||
@@ -245,8 +247,9 @@ struct MultiplyAdd
|
||||
const bhalf_t& d0,
|
||||
const bhalf_t& d1) const
|
||||
{
|
||||
const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
|
||||
e = y;
|
||||
const bhalf_t y =
|
||||
type_convert<bhalf_t>(c * type_convert<float>(d0) + type_convert<float>(d1));
|
||||
e = y;
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
|
||||
|
||||
@@ -334,14 +334,14 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
struct Problem
|
||||
{
|
||||
__host__ Problem() = default;
|
||||
__host__ Problem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
std::array<index_t, NumATensor> StrideAs_,
|
||||
std::array<index_t, NumBTensor> StrideBs_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t KBatch_)
|
||||
__host__ __device__ Problem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
std::array<index_t, NumATensor> StrideAs_,
|
||||
std::array<index_t, NumBTensor> StrideBs_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t KBatch_)
|
||||
: M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
|
||||
@@ -351,64 +351,65 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
|
||||
// Calculate grid size taking into account splitk (KBatch)
|
||||
// 2D grid (x,z)
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
|
||||
__host__ __device__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
// Calculate grid size taking into account splitk (KBatch) and multiple groups (Batch)
|
||||
// 3D grid (x,y,z)
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
|
||||
__host__ __device__ static auto
|
||||
CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
|
||||
{
|
||||
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateMPadded(index_t M)
|
||||
__host__ __device__ static auto CalculateMPadded(index_t M)
|
||||
{
|
||||
return math::integer_least_multiple(M, MPerBlock);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateNPadded(index_t N)
|
||||
__host__ __device__ static auto CalculateNPadded(index_t N)
|
||||
{
|
||||
return math::integer_least_multiple(N, NPerBlock);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateKPadded(index_t K)
|
||||
__host__ __device__ static auto CalculateKPadded(index_t K)
|
||||
{
|
||||
return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
|
||||
}
|
||||
|
||||
__host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
|
||||
__host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
|
||||
__host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
|
||||
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * KPerBlock;
|
||||
}
|
||||
|
||||
__host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
|
||||
__host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
|
||||
auto K_t = K_Batch * KReadVec;
|
||||
return (K + K_t - 1) / K_t * KReadVec;
|
||||
}
|
||||
|
||||
__host__ static auto CalculateMBlock(index_t M)
|
||||
__host__ __device__ static auto CalculateMBlock(index_t M)
|
||||
{
|
||||
return math::integer_divide_ceil(M, MPerBlock);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateNBlock(index_t N)
|
||||
__host__ __device__ static auto CalculateNBlock(index_t N)
|
||||
{
|
||||
return math::integer_divide_ceil(N, NPerBlock);
|
||||
}
|
||||
@@ -963,14 +964,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const index_t num_loop = K / KPerBlock;
|
||||
|
||||
return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
|
||||
__host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
|
||||
{
|
||||
const index_t num_loop = K / KPerBlock;
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
#include "ck/utility/enable_if.hpp"
|
||||
#include <tuple>
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -220,4 +221,49 @@ constexpr Tuple<Args&...> tie(Args&... args) noexcept
|
||||
return {args...};
|
||||
}
|
||||
|
||||
//
|
||||
// tuple_map: Map tuple with a different type
|
||||
// e.g. tuple_map<Wrapper, Tuple<T1, T2, T3>> becomes Tuple<Wrapper<T1>, Wrapper<T2>, Wrapper<T3>>
|
||||
//
|
||||
template <template <typename> class Wrapper, typename Tuple>
|
||||
struct tuple_map;
|
||||
|
||||
template <template <typename> class Wrapper, typename... Ts>
|
||||
struct tuple_map<Wrapper, Tuple<Ts...>>
|
||||
{
|
||||
using type = Tuple<Wrapper<Ts>...>;
|
||||
};
|
||||
|
||||
template <template <typename> class Wrapper, typename Tuple>
|
||||
using tuple_map_t = typename tuple_map<Wrapper, Tuple>::type;
|
||||
|
||||
//
|
||||
// tuple_element_or: helper to access type element of a tuple by index, with the option to default
|
||||
// to a type if the index is out of range of the tuple size
|
||||
//
|
||||
namespace detail {
|
||||
|
||||
// Base template (will be specialized on the boolean)
|
||||
template <ck::index_t N, typename Tuple, typename Default, bool InRange = (N < Tuple::Size())>
|
||||
struct tuple_element_or_impl;
|
||||
|
||||
// Specialization for the in-range case: use tuple_element_t
|
||||
template <ck::index_t N, typename Tuple, typename Default>
|
||||
struct tuple_element_or_impl<N, Tuple, Default, true>
|
||||
{
|
||||
using type = tuple_element_t<N, Tuple>;
|
||||
};
|
||||
|
||||
// Specialization for the out-of-range case: use Default
|
||||
template <ck::index_t N, typename Tuple, typename Default>
|
||||
struct tuple_element_or_impl<N, Tuple, Default, false>
|
||||
{
|
||||
using type = Default;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
// User-facing alias
|
||||
template <ck::index_t N, typename Tuple, typename Default>
|
||||
using tuple_element_or_t = typename detail::tuple_element_or_impl<N, Tuple, Default>::type;
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user