mirror of
https://github.com/ROCm/composable_kernel.git
This reverts commit a4f72a314a.
@@ -8,57 +8,6 @@ namespace ck {
namespace tensor_operation {
namespace device {

///
/// @brief Structure representing single GEMM problem arguments.
///
/// The pointer to the vector of those structures is passed
/// to the GroupedGEMM entry point kernel.
///
struct GroupedGemmKernelArguments
{
    __host__ __device__ GroupedGemmKernelArguments(const void* p_a_grid_,
                                                   const void* p_b_grid_,
                                                   void* p_c_grid_,
                                                   index_t M_,
                                                   index_t N_,
                                                   index_t K_,
                                                   index_t StrideA_,
                                                   index_t StrideB_,
                                                   index_t StrideC_)
        : p_a_grid{p_a_grid_},
          p_b_grid{p_b_grid_},
          p_c_grid{p_c_grid_},
          M{M_},
          N{N_},
          K{K_},
          StrideA{StrideA_},
          StrideB{StrideB_},
          StrideC{StrideC_}
    {
    }

    const void* p_a_grid;
    const void* p_b_grid;
    void* p_c_grid;
    index_t M;
    index_t N;
    index_t K;
    index_t StrideA;
    index_t StrideB;
    index_t StrideC;

    void Print() const
    {
        std::cout << "arg {"
                  << "M:" << M << ", "
                  << "N:" << N << ", "
                  << "K:" << K << ", "
                  << "SA:" << StrideA << ", "
                  << "SB:" << StrideB << ", "
                  << "SC:" << StrideC << "}" << std::endl;
    }
};
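For context, not part of the diff: a minimal host-side sketch of how such an argument array could be assembled and staged in device memory before launch; the `groups` container and its fields are illustrative assumptions, not CK API.

    // Sketch: build one GroupedGemmKernelArguments per GEMM problem, then copy
    // the array into a device buffer (HIP runtime assumed).
    std::vector<GroupedGemmKernelArguments> host_args;
    for(const auto& g : groups) // hypothetical per-group problem descriptors
        host_args.emplace_back(g.p_a, g.p_b, g.p_c, g.M, g.N, g.K,
                               g.StrideA, g.StrideB, g.StrideC);

    void* p_dev_args = nullptr;
    hip_check_error(hipMalloc(&p_dev_args, host_args.size() * sizeof(GroupedGemmKernelArguments)));
    hip_check_error(hipMemcpy(p_dev_args,
                              host_args.data(),
                              host_args.size() * sizeof(GroupedGemmKernelArguments),
                              hipMemcpyHostToDevice));
    // p_dev_args is what SetDeviceKernelArgs (below) later receives.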

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
@@ -82,28 +31,7 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout,
                                                           BElementwiseOperation,
                                                           CElementwiseOperation>
{
    //----------------------------------------------------------------------------------------------
    /// @brief Sets the k batch size.
    ///
    /// @param p_arg       Pointer to the Argument we're going to change.
    /// @param[in] kbatch  The kbatch value.
    ///
    virtual void SetKBatchSize([[maybe_unused]] BaseArgument* p_arg,
                               [[maybe_unused]] index_t kbatch) const
    {
    }

    //----------------------------------------------------------------------------------------------
    /// @brief Sets the device kernel arguments pointer.
    ///
    /// @param p_arg                  The pointer to the Argument we're going to update.
    /// @param[in] p_dev_kernel_args  The pointer to the device memory which contains kernel
    ///                               arguments.
    ///
    virtual void SetDeviceKernelArgs([[maybe_unused]] BaseArgument* p_arg,
                                     [[maybe_unused]] const void* p_dev_kernel_args) const
    {
    }
    virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
};

} // namespace device
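For context, not part of the diff: a hedged usage sketch of this interface; `gemm` (a pointer to a concrete grouped-GEMM device op) and `arg` (a BaseArgument from its MakeArgumentPointer) are assumptions.

    // Sketch: configure split-K, then point the op at device-resident kernel args.
    gemm->SetKBatchSize(arg.get(), 4);                // split the K dimension into 4 batches
    gemm->SetDeviceKernelArgs(arg.get(), p_dev_args); // buffer staged as in the earlier sketch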

@@ -22,22 +22,22 @@ template <typename InDataType,
          index_t NumReduceDim>
struct DeviceSoftmax : public BaseOperator
{
    ///
    /// @brief Makes a pointer to the Argument class.
    ///
    /// @param[in] inLengths          Input tensor extent(s) from high to low dimension
    /// @param[in] inStrides          Input tensor stride(s) from high to low dimension
    /// @param[in] reduceDims         The dimension(s) the normalization operation is applied to
    /// @param[in] alpha              double type value
    /// @param[in] beta               double type value
    /// @param[in] in_dev             Typeless const pointer in device memory storing the input
    ///                               tensor
    /// @param     out_dev            Typeless pointer in device memory storing the output tensor
    /// @param[in] in_elementwise_op  The input elementwise operation.
    /// @param[in] acc_elementwise_op The accumulation elementwise operation.
    ///
    /// @return Unique pointer to the Argument class.
    ///
    //
    // @brief Makes a pointer to the Argument class.
    //
    // @param[in] inLengths          Input tensor extent(s) from high to low dimension
    // @param[in] inStrides          Input tensor stride(s) from high to low dimension
    // @param[in] reduceDims         The dimension(s) the normalization operation is applied to
    // @param[in] alpha              double type value
    // @param[in] beta               double type value
    // @param[in] in_dev             Typeless const pointer in device memory storing the input
    //                               tensor
    // @param     out_dev            Typeless pointer in device memory storing the output tensor
    // @param[in] in_elementwise_op  The input elementwise operation.
    // @param[in] acc_elementwise_op The accumulation elementwise operation.
    //
    // @return Unique pointer to the Argument class.
    //
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::vector<index_t> inLengths,
                        const std::vector<index_t> inStrides,

@@ -168,7 +168,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                    stream_config.stream_id_));

                ave_time = launch_and_time_kernel(
                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map);
            };

            if(has_main_k0_block_loop)

@@ -157,22 +157,22 @@ __global__ void
}
} // namespace

///
/// @brief Device Convolution operation.
///
/// Supports:
/// @li Forward convolution with up to 3 spatial dimensions
/// @li Input tensor in GNWC data format
/// @li Weight tensor in GKXC data format
/// @li Output tensor in GNWK data format
///
/// 1D:
/// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
/// 2D:
/// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
/// 3D:
/// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
///
//
// @brief Device Convolution operation.
//
// Supports:
// @li Forward convolution with up to 3 spatial dimensions
// @li Input tensor in GNWC data format
// @li Weight tensor in GKXC data format
// @li Output tensor in GNWK data format
//
// 1D:
// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
// 2D:
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
// 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
//
template <index_t NDimSpatial,
          typename ADataType,
          typename BDataType,
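For reference, not part of the diff: a naive scalar sketch of the 1D formula above, out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C], assuming stride 1, no padding or dilation, and float data (so Wo = Wi - X + 1); the 2D and 3D cases add the Y/Z loops analogously.

    // Naive 1D forward convolution reference (sketch only).
    // out[n][wo][k] = sum over x, c of in[n][wo + x][c] * wei[k][x][c]
    for(int n = 0; n < N; ++n)
        for(int wo = 0; wo < Wo; ++wo)
            for(int k = 0; k < K; ++k)
            {
                float acc = 0.f;
                for(int x = 0; x < X; ++x)
                    for(int c = 0; c < C; ++c)
                        acc += in[n][wo + x][c] * wei[k][x][c];
                out[n][wo][k] = acc; // NWC input, KXC weight, NWK output
            }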

@@ -154,22 +154,22 @@ __global__ void

} // namespace

///
/// @brief Device Convolution operation.
///
/// Supports:
/// @li Forward convolution with up to 3 spatial dimensions
/// @li Input tensor in GNWC data format
/// @li Weight tensor in GKXC data format
/// @li Output tensor in GNWK data format
///
/// 1D:
/// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
/// 2D:
/// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
/// 3D:
/// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
///
//
// @brief Device Convolution operation.
//
// Supports:
// @li Forward convolution with up to 3 spatial dimensions
// @li Input tensor in GNWC data format
// @li Weight tensor in GKXC data format
// @li Output tensor in GNWK data format
//
// 1D:
// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
// 2D:
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
// 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
//
template <
    index_t NDimSpatial,
    typename ADataType,

@@ -150,22 +150,22 @@ __global__ void

} // namespace

///
/// @brief Device Convolution operation.
///
/// Supports:
/// @li Forward convolution with up to 3 spatial dimensions
/// @li Input tensor in GNWC data format
/// @li Weight tensor in GKXC data format
/// @li Output tensor in GNWK data format
///
/// 1D:
/// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
/// 2D:
/// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
/// 3D:
/// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
///
//
// @brief Device Convolution operation.
//
// Supports:
// @li Forward convolution with up to 3 spatial dimensions
// @li Input tensor in GNWC data format
// @li Weight tensor in GKXC data format
// @li Output tensor in GNWK data format
//
// 1D:
// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
// 2D:
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
// 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
//
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,

@@ -5,13 +5,11 @@

#include <iostream>
#include <sstream>
#include <tuple>

#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/host_utility/stream_utility.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
@@ -25,28 +23,8 @@ namespace ck {
namespace tensor_operation {
namespace device {

///
/// @brief Entry point kernel for device-wide Grouped GEMM operation.
///
/// @param[in] gemm_descs_const  The pointer to the array of GEMM descriptor structures.
/// @param[in] tile_count        The overall number of output tiles we divided all groups
///                              into.
/// @param[in] k_batch           The number of batches we split the K dimension into.
///
/// @tparam GridwiseGemm                The specific GridwiseGEMM algorithm implementation.
/// @tparam GemmDesc                    The structure holding all necessary descriptors and
///                                     other data needed for grouped GEMM calculation and work
///                                     distribution.
/// @tparam HasMainKBlockLoop           Flag indicating whether all GEMM problem configurations
///                                     need to loop over tiles in the K dimension.
/// @tparam CGlobalMemoryDataOperation  The functor used to store data in the output C matrix.
///                                     For example: AtomicAdd or Store.
///
template <typename GridwiseGemm,
          typename GemmDesc,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void
@@ -54,99 +32,42 @@ __global__ void
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                       const index_t tile_count,
                                       const index_t k_batch)
                                       const index_t group_count)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))

    constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
    __shared__ uint8_t p_shared[shared_size];

    index_t tile_id = get_block_1d_id();
    const index_t grid_size = get_grid_size();
    const index_t block_id = get_block_1d_id();
    const auto gemm_desc_ptr =
        reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));

    static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
    static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
    static constexpr index_t B2E_M01 = 8;

    using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
    using Block2ETileMapKSplit =
        BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;

    index_t group_id = 0;
    index_t offset = 0;

    auto M = gemm_desc_ptr[group_id].M;
    auto N = gemm_desc_ptr[group_id].N;
    auto StrideC = gemm_desc_ptr[group_id].StrideC;
    auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
    auto b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
    index_t grid_size_grp = b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

    index_t gemm_tile_id_start = 0;
    index_t gemm_tile_id_end = grid_size_grp;

    while(tile_id < tile_count)
    index_t left = 0;
    index_t right = group_count;
    index_t group_id = index_t((left + right) / 2);
    while((!(block_id >= gemm_desc_ptr[group_id].block_start_ &&
             block_id < gemm_desc_ptr[group_id].block_end_)) &&
          left <= right)
    {
        // Find the corresponding GEMM group for the output tile
        while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end))
        if(block_id < gemm_desc_ptr[group_id].block_start_)
        {
            offset += grid_size_grp;
            group_id++;

            M = gemm_desc_ptr[group_id].M;
            N = gemm_desc_ptr[group_id].N;
            StrideC = gemm_desc_ptr[group_id].StrideC;
            c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
            b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
            grid_size_grp = b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

            gemm_tile_id_start = offset;
            gemm_tile_id_end = offset + grid_size_grp;
            right = group_id;
        }

        const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
        const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
        const auto p_c_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_c_grid);

        const auto K = gemm_desc_ptr[group_id].K;
        const auto StrideA = gemm_desc_ptr[group_id].StrideA;
        const auto StrideB = gemm_desc_ptr[group_id].StrideB;

        const auto MPadded = GridwiseGemm::CalculateMPadded(M);
        const auto NPadded = GridwiseGemm::CalculateNPadded(N);
        const auto KPadded = GridwiseGemm::CalculateKPadded(K, k_batch);
        const auto K0 = GridwiseGemm::CalculateK0(K, k_batch);

        LocalBlockToCTileMap<Block2ETileMapKSplit> local_b2c{b2c_tile_map, tile_id - offset};
        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
            p_a_grid,
            p_b_grid,
            p_c_grid,
            M,
            N,
            K,
            StrideA,
            StrideB,
            StrideC,
            MPadded,
            NPadded,
            KPadded,
            K0,
            k_batch,
            static_cast<void*>(p_shared),
            local_b2c);

        tile_id += grid_size;
        else
        {
            left = group_id;
        }
        group_id = index_t((left + right) / 2);
    }

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
        gemm_desc_ptr[group_id].karg_,
        static_cast<void*>(p_shared),
        gemm_desc_ptr[group_id].block_2_ctile_map_);
#else
    ignore = gemm_descs_const;
    ignore = tile_count;
    ignore = k_batch;
    ignore = group_count;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
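For context, not part of the diff: the restored kernel resolves a block's GEMM group by binary-searching the per-group [block_start_, block_end_) ranges, which are contiguous and ascending by construction. A minimal host-side analog of that lookup, mirroring the kernel's own logic:

    // Sketch: return the index of the group whose block range contains block_id,
    // assuming descs[i].block_start_/block_end_ are contiguous and ascending.
    index_t FindGroupId(const GemmTransKernelArg* descs, index_t group_count, index_t block_id)
    {
        index_t left     = 0;
        index_t right    = group_count;
        index_t group_id = (left + right) / 2;
        while(!(block_id >= descs[group_id].block_start_ &&
                block_id < descs[group_id].block_end_) &&
              left <= right)
        {
            if(block_id < descs[group_id].block_start_)
                right = group_id; // target range lies below this group
            else
                left = group_id;  // target range lies above this group
            group_id = (left + right) / 2;
        }
        return group_id;
    }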

@@ -265,13 +186,33 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                                            LoopSched,
                                            PipelineVer>;

    using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
    using GridwiseGemmArg = typename GridwiseGemm::Argument;
    using KernelArguments = GroupedGemmKernelArguments;
    using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
    using Block2ETileMapKSplit =
        BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
    // Block2CTileMap configuration parameter.
    static constexpr index_t B2E_M01 = 8;
    static constexpr index_t B2E_M01 = 8;
    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
    using KernelArgument = typename GridwiseGemm::Argument;

    struct GemmTransKernelArg
    {
        KernelArgument karg_;
        GroupedGemmBlock2ETileMap block_2_ctile_map_;
        index_t block_start_, block_end_;

        GemmTransKernelArg() = default;
        GemmTransKernelArg(KernelArgument&& karg,
                           GroupedGemmBlock2ETileMap&& b2c_map,
                           index_t block_start,
                           index_t block_end)
            : karg_{karg},
              block_2_ctile_map_{b2c_map},
              block_start_{block_start},
              block_end_{block_end}
        {
        }
    };

    static constexpr index_t DefaultKBatch = 1;

    // Argument
@@ -284,6 +225,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 std::vector<GemmDesc>& gemm_descs)
            : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch)
        {
            // TODO: use occupancy api to calculate appropriate batch size.
        }

        Argument(std::vector<const void*>& p_As,
@@ -291,8 +233,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 std::vector<void*>& p_Es,
                 std::vector<GemmDesc>& gemm_descs,
                 index_t kbatch)
            : K_BATCH{kbatch}, group_count_{0}, skipped_group_count_{0}, grid_size_{0}
            : K_BATCH{kbatch}
        {
            grid_size_ = 0;
            group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());

            if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
@@ -304,6 +247,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo

            gemm_kernel_args_.reserve(group_count_);

            skipped_group_count_ = 0;

            for(std::size_t i = 0; i < gemm_descs.size(); ++i)
            {
                const index_t M = gemm_descs[i].M_;
@@ -320,29 +265,51 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                const index_t stride_b = gemm_descs[i].stride_B_;
                const index_t stride_c = gemm_descs[i].stride_C_;

                const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
                const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
                const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
                const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH);

                const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);

                auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                const auto local_b2c_tile_map =
                    Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

                const index_t block_start = grid_size_;
                const index_t block_end = grid_size_ + grid_size_grp;

                grid_size_ += grid_size_grp;

                gemm_kernel_args_.emplace_back(type_convert<const ADataType*>(p_As[i]),
                                               type_convert<const BDataType*>(p_Bs[i]),
                                               type_convert<EDataType*>(p_Es[i]),
                                               M,
                                               N,
                                               K,
                                               stride_a,
                                               stride_b,
                                               stride_c);
                // block-to-e-tile map
                auto grouped_block_2_ctile_map =
                    GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);

                auto karg = KernelArgument{type_convert<const ADataType*>(p_As[i]),
                                           type_convert<const BDataType*>(p_Bs[i]),
                                           type_convert<EDataType*>(p_Es[i]),
                                           M,
                                           N,
                                           K,
                                           stride_a,
                                           stride_b,
                                           stride_c,
                                           m_padded,
                                           n_padded,
                                           k_padded,
                                           k0,
                                           K_BATCH};

                gemm_kernel_args_.emplace_back(
                    std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
            }
        }

        ///
        /// @brief Set new kbatch value.
        ///
        /// @param[in] kbatch The new splitK parameter value.
        ///
        /**
         * @brief Recalculate group grid size for all gemms and update B2C maps.
         *
         * @param[in] kbatch The new splitK parameter value.
         */
        void UpdateKBatch(index_t kbatch)
        {
            K_BATCH = kbatch;
@@ -351,14 +318,33 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
            for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
            {

                auto& gemm_arg = gemm_kernel_args_[i];
                auto& karg = gemm_kernel_args_[i].karg_;

                const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
                const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH);

                const auto c_grid_desc_m_n =
                    GridwiseGemm::MakeCGridDescriptor_M_N(gemm_arg.M, gemm_arg.N, gemm_arg.StrideC);
                    GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);

                auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                const auto local_b2c_tile_map =
                    Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

                const index_t block_start = grid_size_;
                const index_t block_end = grid_size_ + grid_size_grp;

                grid_size_ += grid_size_grp;

                // block-to-e-tile map
                auto grouped_block_2_ctile_map =
                    GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);

                karg.KPadded = k_padded;
                karg.K0 = k0;
                karg.k_batch = K_BATCH;
                gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
                gemm_kernel_args_[i].block_start_ = block_start;
                gemm_kernel_args_[i].block_end_ = block_end;
            }
        }

@@ -366,167 +352,31 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
        index_t K_BATCH;
        index_t group_count_;
        index_t skipped_group_count_;
        // The overall number of output tiles to be processed.
        index_t grid_size_;
        const void* p_dev_gemm_args_;

        std::vector<KernelArguments> gemm_kernel_args_;
        std::vector<GemmTransKernelArg> gemm_kernel_args_;
        index_t grid_size_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        // The oversubscription factor for the number of blocks that can simultaneously reside on
        // the GPU.
        static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
        static constexpr int BLOCK_WAVES = BlockSize / get_warp_size();
        static constexpr int CU_SIMDS = 4;
        // Assume we want to have at most 2 waves per SIMD
        static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
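        // Worked example (added for clarity, not in the diff; numbers are assumptions):
        // with BlockSize = 256 and a 64-lane wavefront, BLOCK_WAVES = 256 / 64 = 4, so
        // CU_BLOCKS = floor((2 * 4) / 4) = 2 resident workgroups per CU.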

        ///
        /// @brief Launch Grouped Gemm kernel.
        ///
        /// @note This function overload is using a user-provided device buffer for kernel
        ///       arguments.
        ///
        /// @param[in] arg           The structure containing kernel arguments (in host memory).
        /// @param[in] dev_gemm_args The pointer to device memory with kernel arguments.
        /// @param[in] stream_config The device stream configuration.
        ///
        /// @return The average kernel execution time (if time measurement is enabled).
        ///
        float Run(const Argument& arg,
                  const void* dev_gemm_args,
                  const StreamConfig& stream_config = StreamConfig{})
        {
            auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] =
                CheckArgument(arg, stream_config);

            if(dev_gemm_args == nullptr)
            {
                std::ostringstream err;
                err << "The gemm arguments workspace buffer is not allocated!"
                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                throw std::runtime_error(err.str());
            }

            if(all_have_kbatch_gt_one)
            {
                for(const auto& gemm_arg : arg.gemm_kernel_args_)
                {
                    hip_check_error(hipMemsetAsync(gemm_arg.p_c_grid,
                                                   0,
                                                   gemm_arg.M * gemm_arg.N * sizeof(EDataType),
                                                   stream_config.stream_id_));
                }
            }

            float ave_time = 0;

            if(all_have_main_k0_block_loop)
            {
                if(all_have_kbatch_gt_one)
                {
                    ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>(
                        arg, dev_gemm_args, stream_config);
                }
                else
                {
                    ave_time = DispatchKernel<InMemoryDataOperationEnum::Set, true>(
                        arg, dev_gemm_args, stream_config);
                }
            }
            else
            {
                if(all_have_kbatch_gt_one)
                {
                    ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>(
                        arg, dev_gemm_args, stream_config);
                }
                else
                {
                    ave_time = DispatchKernel<InMemoryDataOperationEnum::Set, false>(
                        arg, dev_gemm_args, stream_config);
                }
            }

            return ave_time;
        }

        ///
        /// @brief Launch Grouped Gemm kernel.
        ///
        /// @note This function overload is using the device workspace buffer for kernel
        ///       arguments. The user should call @see GetWorkSpaceSize and @see
        ///       SetWorkSpacePointer on the arg parameter to properly allocate this buffer.
        ///
        /// @param[in] arg           The structure containing kernel arguments (in host memory).
        /// @param[in] stream_config The device stream configuration.
        ///
        /// @return The average kernel execution time (if time measurement is enabled).
        ///
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(arg.p_workspace_ != nullptr)
            {
                hip_check_error(
                    hipMemcpyWithStream(arg.p_workspace_,
                                        arg.gemm_kernel_args_.data(),
                                        arg.gemm_kernel_args_.size() * sizeof(KernelArguments),
                                        hipMemcpyHostToDevice,
                                        stream_config.stream_id_));
            }
            else
            {
                std::ostringstream err;
                err << "The gemm arguments workspace buffer is not allocated!"
                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                throw std::runtime_error(err.str());
            }

            return Run(arg, arg.p_workspace_, stream_config);
        }
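For context, not part of the diff: a hedged sketch of the workspace flow this overload expects; `gemm`, `p_invoker`, and `p_arg` are assumed to come from the usual MakeArgumentPointer/MakeInvokerPointer calls.

    // Sketch: size, allocate, and attach the kernel-arguments workspace, then run.
    std::size_t ws_size = gemm.GetWorkSpaceSize(p_arg.get());
    void* p_ws = nullptr;
    hip_check_error(hipMalloc(&p_ws, ws_size));
    gemm.SetWorkSpacePointer(p_arg.get(), p_ws); // stored as p_arg->p_workspace_
    float ave_time = p_invoker->Run(p_arg.get(), StreamConfig{});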

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }

        private:
        auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
        {
            index_t K0 = GridwiseGemm::CalculateK0(arg.gemm_kernel_args_[0].K, arg.K_BATCH);
            bool all_have_kbatch_gt_one = arg.K_BATCH > 1;
            index_t K0 = arg.gemm_kernel_args_[0].karg_.K0;
            bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
            bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

            for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
            {
                const auto& gemm_arg = arg.gemm_kernel_args_[i];
                const auto& karg = arg.gemm_kernel_args_[i].karg_;
                if(stream_config.log_level_ > 0)
                {
                    gemm_arg.Print();
                    karg.Print();
                }

                // Currently all groups use the same kbatch value.
                auto kbatch = arg.K_BATCH;
                K0 = GridwiseGemm::CalculateK0(arg.gemm_kernel_args_[i].K, arg.K_BATCH);
                auto kbatch = karg.k_batch;

                if(!GridwiseGemm::CheckValidity(GridwiseGemmArg{nullptr,
                                                                nullptr,
                                                                nullptr,
                                                                gemm_arg.M,
                                                                gemm_arg.N,
                                                                gemm_arg.K,
                                                                gemm_arg.StrideA,
                                                                gemm_arg.StrideB,
                                                                gemm_arg.StrideC,
                                                                0, // MPadded
                                                                0, // NPadded
                                                                0, // KPadded
                                                                K0,
                                                                kbatch}))
                if(!GridwiseGemm::CheckValidity(karg))
                {
                    std::ostringstream err;
                    err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
@@ -534,6 +384,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                    throw std::runtime_error(err.str());
                }

                K0 = karg.K0;
                bool not_all_have_main_k0_block_loop_same =
                    all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
                bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
@@ -551,75 +402,99 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                    std::ostringstream err;
                    err << "Not all gemms have same kbatch value (=1 or >1)! "
                        << "group [" << i << "], kbatch: " << kbatch
                        << ", group [0], kbatch: " << arg.K_BATCH << " in " << __FILE__ << ":"
                        << __LINE__ << ", in function: " << __func__;
                        << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch
                        << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                    throw std::runtime_error(err.str());
                }
            }
            return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k0_block_loop);
        }

        template <InMemoryDataOperationEnum CGlobalMemoryDataOperation, bool HasMainKBlockLoop>
        float DispatchKernel(const Argument& arg,
                             const void* dev_gemm_args,
                             const StreamConfig& stream_config) const
        {
            const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                               KernelArguments,
                                                               ADataType,
                                                               BDataType,
                                                               EDataType,
                                                               HasMainKBlockLoop,
                                                               CGlobalMemoryDataOperation>;
            return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
        }
            hip_check_error(
                hipMemcpyWithStream(arg.p_workspace_,
                                    arg.gemm_kernel_args_.data(),
                                    arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
                                    hipMemcpyHostToDevice,
                                    stream_config.stream_id_));

        template <typename KernelFunction>
        int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
                                          const StreamConfig& stream_config) const
        {
            // Calculate the max number of workgroups that can simultaneously reside on the CU.
            int num_blocks = 0;
            size_t dyn_shared_mem_per_blk = 0;
            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
                &num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));
            float ave_time = 0;

            int cu_count = getAvailableComputeUnitCount(stream_config);
            const auto Run = [&](const auto& kernel) {
                if(all_have_kbatch_gt_one)
                {
                    for(const auto& trans_arg : arg.gemm_kernel_args_)
                    {
                        const auto& karg = trans_arg.karg_;
                        hip_check_error(hipMemsetAsync(karg.p_c_grid,
                                                       0,
                                                       karg.M * karg.N * sizeof(EDataType),
                                                       stream_config.stream_id_));
                    }
                }

                if(stream_config.log_level_ > 0)
                ave_time =
                    launch_and_time_kernel(stream_config,
                                           kernel,
                                           dim3(arg.grid_size_),
                                           dim3(BlockSize),
                                           0,
                                           cast_pointer_to_constant_address_space(arg.p_workspace_),
                                           arg.gemm_kernel_args_.size());
            };

            if(all_have_main_k0_block_loop)
            {
                std::cout << "MaxActiveBlocksPerCU: " << num_blocks
                          << ", available CUs count: " << cu_count << ", occup. grid size: "
                          << ck::math::min(num_blocks, CU_BLOCKS) * cu_count *
                                 BLOCK_SUBSCRIPTION_FACTOR
                          << std::endl;
                if(all_have_kbatch_gt_one)
                {
                    const auto kernel =
                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                       GemmTransKernelArg,
                                                       true,
                                                       InMemoryDataOperationEnum::AtomicAdd>;

                    Run(kernel);
                }
                else
                {
                    const auto kernel =
                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                       GemmTransKernelArg,
                                                       true,
                                                       InMemoryDataOperationEnum::Set>;

                    Run(kernel);
                }
            }
            else
            {
                if(all_have_kbatch_gt_one)
                {
                    const auto kernel =
                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                       GemmTransKernelArg,
                                                       false,
                                                       InMemoryDataOperationEnum::AtomicAdd>;

                    Run(kernel);
                }
                else
                {
                    const auto kernel =
                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                       GemmTransKernelArg,
                                                       false,
                                                       InMemoryDataOperationEnum::Set>;

                    Run(kernel);
                }
            }

            return cu_count * ck::math::min(num_blocks, CU_BLOCKS) * BLOCK_SUBSCRIPTION_FACTOR;
            return ave_time;
        }
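        // Worked example (added for clarity, not in the diff; all numbers hypothetical):
        // if the occupancy query reports num_blocks = 4 but CU_BLOCKS caps it at 2, a
        // device with cu_count = 104 CUs yields 104 * min(4, 2) * 1 = 208 workgroups.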

        template <typename KernelFunction>
        float LaunchKernel(const KernelFunction& kernel,
                           const Argument& arg,
                           const void* dev_gemm_args,
                           const StreamConfig& stream_config) const
        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            int max_occupancy_grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config);

            // We launch the smaller of the number of workgroups actually needed for the tiles
            // and the number of workgroups that maximizes GPU occupancy, because for some tile
            // configurations the first is smaller than the latter. Launching too many workgroups
            // means some of them would have to iterate through all gemm problem descriptors just
            // to find out they have nothing to do, which is, of course, a waste of GPU cycles.
            return launch_and_time_kernel(
                stream_config,
                kernel,
                dim3(ck::math::min(arg.grid_size_, max_occupancy_grid_size)),
                dim3(BlockSize),
                0,
                cast_pointer_to_constant_address_space(dev_gemm_args),
                arg.grid_size_,
                arg.K_BATCH);
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

@@ -631,6 +506,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
            arg.skipped_group_count_) != arg.group_count_)
        {
@@ -645,28 +525,14 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
        bool supported = true;
        for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
        {
            const auto& gemm_arg = arg.gemm_kernel_args_[i];
            const auto K0 = GridwiseGemm::CalculateK0(gemm_arg.K, arg.K_BATCH);
            bool group_arg_valid = GridwiseGemm::CheckValidity(GridwiseGemmArg{nullptr,
                                                                               nullptr,
                                                                               nullptr,
                                                                               gemm_arg.M,
                                                                               gemm_arg.N,
                                                                               gemm_arg.K,
                                                                               gemm_arg.StrideA,
                                                                               gemm_arg.StrideB,
                                                                               gemm_arg.StrideC,
                                                                               0, // MPadded
                                                                               0, // NPadded
                                                                               0, // KPadded
                                                                               K0,
                                                                               arg.K_BATCH});
            const auto& a = arg.gemm_kernel_args_[i].karg_;
            bool group_arg_valid = GridwiseGemm::CheckValidity(a);
            if(not group_arg_valid)
            {
#if DEBUG_LOG
                std::cout << "[" << __func__ << "] group id: " << i
                          << " has invalid GridwiseGemm settings!" << std::endl;
                gemm_arg.Print();
                a.Print();
#endif // DEBUG_LOG
            }
            supported = supported && group_arg_valid;
@@ -674,6 +540,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
        return supported;
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
@@ -693,6 +560,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::vector<const void*>& p_As,
                        std::vector<const void*>& p_Bs,
@@ -706,17 +574,19 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
        return std::make_unique<Argument>(p_As, p_Bs, p_Es, gemm_descs);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceGroupedGemm_XdlSplitKTileLoop"
        str << "DeviceGroupedGemm_XdlSplitK"
            << "<"
            << std::string(ALayout::name)[0] << ","
            << std::string(BLayout::name)[0] << ","
@@ -735,9 +605,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
            << BBlockTransferSrcScalarPerVector << ", "
            << CShuffleMXdlPerWavePerShuffle << ", "
            << CShuffleNXdlPerWavePerShuffle << ", "
            << ABlockTransferThreadClusterLengths_K0_M_K1{} << ", "
            << getGemmSpecializationString(GemmSpec) << ", "
            << PipelineVer
            << getGemmSpecializationString(GemmSpec)
            << ">";
        // clang-format on

@@ -747,24 +615,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
    {
        return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() *
               sizeof(KernelArguments);
               sizeof(GemmTransKernelArg);
    }

    static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
    static void SetDeviceKernelArgs(Argument& arg, const void* p_dev_kernel_args)
    {
        arg.p_dev_gemm_args_ = p_dev_kernel_args;
    }

    // polymorphic
    void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
    {
        return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
    }

    void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p_dev_kernel_args) const override
    {
        return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
    }
};

} // namespace device

@@ -348,24 +348,24 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
                        acc_elementwise_op};
    };

    ///
    /// @brief Makes a pointer to the Argument class.
    ///
    /// @param[in] inLengths          Input tensor extent(s) from high to low dimension
    /// @param[in] inStrides          Input tensor stride(s) from high to low dimension
    /// @param[in] reduceDims         The dimension(s) the normalization operation is applied to
    /// @param[in] alpha              Typeless pointer in host memory storing the alpha scaling
    ///                               value as type AccDataType
    /// @param[in] beta               Typeless pointer in host memory storing the beta scaling
    ///                               value as type AccDataType
    /// @param[in] in_dev             Typeless const pointer in device memory storing the input
    ///                               tensor
    /// @param     out_dev            Typeless pointer in device memory storing the output tensor
    /// @param[in] in_elementwise_op  The input elementwise operation.
    /// @param[in] acc_elementwise_op The accumulation elementwise operation.
    ///
    /// @return Unique pointer to the Argument class.
    ///
    //
    // @brief Makes a pointer to the Argument class.
    //
    // @param[in] inLengths          Input tensor extent(s) from high to low dimension
    // @param[in] inStrides          Input tensor stride(s) from high to low dimension
    // @param[in] reduceDims         The dimension(s) the normalization operation is applied to
    // @param[in] alpha              Typeless pointer in host memory storing the alpha scaling
    //                               value as type AccDataType
    // @param[in] beta               Typeless pointer in host memory storing the beta scaling
    //                               value as type AccDataType
    // @param[in] in_dev             Typeless const pointer in device memory storing the input
    //                               tensor
    // @param     out_dev            Typeless pointer in device memory storing the output tensor
    // @param[in] in_elementwise_op  The input elementwise operation.
    // @param[in] acc_elementwise_op The accumulation elementwise operation.
    //
    // @return Unique pointer to the Argument class.
    //
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
                                                      const std::vector<index_t> inStrides,
                                                      const std::vector<int> reduceDims,

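For context, not part of the diff: a hedged usage sketch of MakeArgumentPointer as documented above, assuming a concrete DeviceSoftmaxImpl instance `softmax`, AccDataType = float, CK's PassThrough elementwise op, and device pointers `p_in_dev`/`p_out_dev`; all names are illustrative.

    // Sketch: softmax over the innermost dimension of a [64, 1024] input.
    float alpha = 1.0f, beta = 0.0f; // read through the typeless pointers as AccDataType
    auto arg = softmax.MakeArgumentPointer({64, 1024}, // inLengths, high to low
                                           {1024, 1},  // inStrides, high to low
                                           {1},        // reduceDims: innermost dimension
                                           &alpha,
                                           &beta,
                                           p_in_dev,   // const input tensor on device
                                           p_out_dev,  // output tensor on device
                                           PassThrough{},
                                           PassThrough{});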