mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
CK-Tile Grouped GEMM refactor and post PR fixes (#1756)
* Grouped gemm simple code refactor * Offset invoker * Invoke generic Run, and replace name of parrtitioner variable * Tests fix type * Removed namespaces * Add template param to avoid implicit cast * Remove generic function * Constant value * underline enum to int16_t * Generalize partitioner function * Remove whitespaces * Rename function * Using support * Clang-format * Clang-format * Fn-partitioner description fn * Typo * Typo 2 * Better description * Better description * Refactor after review * Use ctr instead of set fn * Inovke ctr and typo * Comments * Remove unnecessary comment * Review, remove modulo
This commit is contained in:
@@ -101,9 +101,12 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
|
||||
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
|
||||
{
|
||||
const auto [i_m, i_n] = TilePartitioner{}();
|
||||
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
|
||||
const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
|
||||
const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
|
||||
const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
|
||||
|
||||
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k);
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -174,7 +174,7 @@ struct GemmKernel
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
|
||||
if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -185,7 +185,7 @@ struct GemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
|
||||
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -197,7 +197,7 @@ struct GemmKernel
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -208,7 +208,7 @@ struct GemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
|
||||
if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -220,7 +220,7 @@ struct GemmKernel
|
||||
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -231,7 +231,7 @@ struct GemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
|
||||
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -323,17 +323,17 @@ struct GemmKernel
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -341,17 +341,17 @@ struct GemmKernel
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
|
||||
sequence<GemmPipeline::kPadN, false>{});
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadN, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -359,17 +359,17 @@ struct GemmKernel
|
||||
const auto& c_tensor_view = views.at(I2);
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -383,19 +383,19 @@ struct GemmKernel
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& a_block_window = make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
|
||||
const auto& b_pad_view = views.at(I1);
|
||||
const auto& b_block_window = make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
|
||||
const auto& c_pad_view = views.at(I2);
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return make_tuple(a_block_window, b_block_window, c_block_window);
|
||||
@@ -426,7 +426,7 @@ struct GemmKernel
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
;
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
@@ -456,7 +456,10 @@ struct GemmKernel
|
||||
|
||||
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
|
||||
{
|
||||
const auto [i_m, i_n] = TilePartitioner{}();
|
||||
const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr =
|
||||
|
||||
@@ -1,73 +1,160 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <typename BlockGemmShape_>
|
||||
struct GemmTilePartitioner
|
||||
|
||||
/** @brief Struct representing 2D block index mapping into 3D output tile space. */
|
||||
template <typename BlockGemmShapeType>
|
||||
struct GemmTile2DPartitioner
|
||||
{
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t kM = BlockGemmShape::kM;
|
||||
static constexpr index_t kN = BlockGemmShape::kN;
|
||||
static constexpr index_t kK = BlockGemmShape::kK;
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size)
|
||||
{
|
||||
index_t GridDimX = (M + kM - 1) / kM;
|
||||
index_t GridDimY = (N + kN - 1) / kN;
|
||||
index_t GridDimZ = batch_size;
|
||||
return dim3(GridDimX, GridDimY, GridDimZ);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K)
|
||||
{
|
||||
return integer_divide_ceil(K, kK);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()()
|
||||
{
|
||||
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kM);
|
||||
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kN);
|
||||
return make_tuple(iM, iN);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename BlockGemmShape_>
|
||||
struct GemmTile1DPartitioner
|
||||
{
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N)
|
||||
/** @brief Returns 3D grid size. */
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size) noexcept(
|
||||
noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
|
||||
{
|
||||
index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
|
||||
index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
|
||||
return dim3(GridDimX * GridDimY, 1, 1);
|
||||
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
|
||||
const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
|
||||
const index_t GridDimZ = batch_size;
|
||||
return dim3(GridDimX, GridDimY, GridDimZ);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N)
|
||||
{
|
||||
return integer_divide_ceil(N, NPerBlock);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K)
|
||||
/**
|
||||
* @brief Returns the number of loops.
|
||||
* @param [in] K is dimension
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
|
||||
{
|
||||
return integer_divide_ceil(K, KPerBlock);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(index_t blockOffset, index_t NBlockSize)
|
||||
/**
|
||||
* @brief The function returns 2D output tile space.
|
||||
* @param [in] blockIdx is blockIdx.x
|
||||
* @param [in] blockIdy is blockIdx.y
|
||||
* @return Returns the output tile indexes.
|
||||
*/
|
||||
CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx,
|
||||
index_t blockIdy) noexcept
|
||||
-> const tuple<index_t, index_t>
|
||||
{
|
||||
index_t iM = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) /
|
||||
GetNBlock(NBlockSize) * MPerBlock);
|
||||
index_t iN = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) %
|
||||
GetNBlock(NBlockSize) * NPerBlock);
|
||||
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
|
||||
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdy);
|
||||
return make_tuple(iM, iN);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Struct representing 1D block index mapping into 2D output tile space.
|
||||
*/
|
||||
template <typename BlockGemmShapeType>
|
||||
struct GemmTile1DPartitioner
|
||||
{
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
/** @brief delete default ctr with no any object */
|
||||
constexpr GemmTile1DPartitioner() noexcept = delete;
|
||||
|
||||
/** @brief constructs an object that does contain a N value. */
|
||||
constexpr GemmTile1DPartitioner(index_t N) noexcept { N_ = N; }
|
||||
|
||||
/** @brief Returns 1D grid size. */
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
|
||||
{
|
||||
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
|
||||
const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
|
||||
return dim3(GridDimX * GridDimY, 1, 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Returns the number of blocks in N.
|
||||
* @param [in] N is dimension
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N) noexcept -> index_t
|
||||
{
|
||||
return integer_divide_ceil(N, NPerBlock);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Returns the number of loops.
|
||||
* @param [in] K is dimension
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
|
||||
{
|
||||
return integer_divide_ceil(K, KPerBlock);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief The function returns 2D output tile space.
|
||||
* @param [in] blockIdx is blockIdx.x - block_start.
|
||||
* */
|
||||
CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx) noexcept
|
||||
-> const tuple<index_t, index_t>
|
||||
{
|
||||
const index_t NBlock = GetNBlock(N_);
|
||||
|
||||
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlock);
|
||||
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - (iM)*NBlock);
|
||||
return make_tuple(iM, iN);
|
||||
}
|
||||
|
||||
private:
|
||||
CK_TILE_DEVICE static index_t N_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief `GemmTile1DPartitioner::GetOutputTileIndex`'s std::false specialization,
|
||||
* checking expression validity in-place for ill-formed.
|
||||
*/
|
||||
template <typename, typename = void>
|
||||
struct HasFnOneArgImpl : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief `GemmTile1DPartitioner::GetOutputTileIndex`'s std::true specialization,
|
||||
* checking expression validity in-place for well-formed.
|
||||
* @note: `1` - a constant value indicating the number of parameters in the function.
|
||||
*/
|
||||
template <typename T>
|
||||
struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIndex(1))>>
|
||||
: std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Struct used to calculate offseted tile indexes.
|
||||
* @note: The struct supports the 1D-Partitioner mechanism,
|
||||
* enable-if `GetOutputTileIndex`-fn is std::true_type when `GetOutputTileIndex`-fn is well-formed,
|
||||
* otherwise std::false_type.
|
||||
*/
|
||||
template <typename PartitionerFn,
|
||||
typename = typename std::enable_if_t<HasFnOneArgImpl<PartitionerFn>{}>>
|
||||
struct OffsettedTile1DPartitioner
|
||||
{
|
||||
/**
|
||||
* @brief The function subtracts the block's start (offset) from 1D raw-indexes.
|
||||
* @param [in] block_start is `blockIdx.x - block_start`.
|
||||
* @return Returns a `tuple` [Im, In] shifted index, used to shift 1d-tile index.
|
||||
*/
|
||||
[[nodiscard]] CK_TILE_DEVICE static constexpr auto GetOffsetedTileIndex(index_t block_start,
|
||||
index_t N) noexcept
|
||||
-> const tuple<index_t, index_t>
|
||||
{
|
||||
const auto [iM, iN] = PartitionerFn(N).GetOutputTileIndex(blockIdx.x - block_start);
|
||||
return make_tuple(iM, iN);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,72 +1,79 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/literals.hpp"
|
||||
#include "ck_tile/core/utility/amd_address_space.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct GroupedGemmHostArgs
|
||||
struct GroupedGemmHostArgs : public ck_tile::GemmHostArgs
|
||||
{
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
void* c_ptr;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
CK_TILE_HOST GroupedGemmHostArgs() noexcept = default;
|
||||
CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* c_ptr_,
|
||||
ck_tile::index_t M_,
|
||||
ck_tile::index_t N_,
|
||||
ck_tile::index_t K_,
|
||||
ck_tile::index_t stride_A_,
|
||||
ck_tile::index_t stride_B_,
|
||||
ck_tile::index_t stride_C_)
|
||||
: GemmHostArgs(a_ptr_, b_ptr_, c_ptr_, KBatch, M_, N_, K_, stride_A_, stride_B_, stride_C_)
|
||||
{
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr index_t KBatch = 1;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct GroupedGemmKernel
|
||||
struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
|
||||
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
using GemmKernelArgs = typename Base::GemmKernelArgs;
|
||||
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr index_t KBatch = 1;
|
||||
|
||||
struct GemmTransKernelArg
|
||||
{
|
||||
GroupedGemmHostArgs group_karg;
|
||||
GemmKernelArgs group_karg;
|
||||
ck_tile::index_t block_start;
|
||||
ck_tile::index_t block_end;
|
||||
|
||||
GemmTransKernelArg() = default;
|
||||
GemmTransKernelArg(GroupedGemmHostArgs&& karg, index_t bl_start, index_t bl_end)
|
||||
GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
|
||||
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
__host__ static size_t GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
|
||||
__host__ static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
|
||||
-> std::size_t
|
||||
{
|
||||
return gemm_descs.size() * sizeof(GemmTransKernelArg);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
__host__ static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }
|
||||
|
||||
using Hargs = GroupedGemmHostArgs;
|
||||
|
||||
__host__ static constexpr auto GridSize(const std::vector<Hargs>& gemm_descs)
|
||||
__host__ static constexpr auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
|
||||
{
|
||||
index_t grid_size = 0;
|
||||
for(const auto& it_desc : gemm_descs)
|
||||
@@ -77,7 +84,8 @@ struct GroupedGemmKernel
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto MakeKargs(const std::vector<Hargs>& gemm_descs)
|
||||
CK_TILE_HOST static auto MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs)
|
||||
-> std::vector<GemmTransKernelArg>
|
||||
{
|
||||
std::vector<GemmTransKernelArg> gemm_kernel_args_;
|
||||
index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
|
||||
@@ -100,22 +108,23 @@ struct GroupedGemmKernel
|
||||
const index_t stride_c = gemm_descs[i].stride_C;
|
||||
|
||||
const auto dim3 = TilePartitioner::GridSize(M, N);
|
||||
const index_t grid_size_grp = dim3.x * 1 * 1;
|
||||
const index_t grid_size_grp = dim3.x;
|
||||
|
||||
const index_t block_start = grid_size;
|
||||
const index_t block_end = grid_size + grid_size_grp;
|
||||
|
||||
grid_size += grid_size_grp;
|
||||
|
||||
auto karg = GroupedGemmHostArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
|
||||
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
|
||||
type_convert<CDataType*>(gemm_descs[i].c_ptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c};
|
||||
auto karg = GemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
|
||||
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
|
||||
type_convert<CDataType*>(gemm_descs[i].c_ptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
KBatch};
|
||||
|
||||
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
|
||||
}
|
||||
@@ -123,162 +132,34 @@ struct GroupedGemmKernel
|
||||
return gemm_kernel_args_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
|
||||
{
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void Run(const Hargs& kargs, const index_t block_start) const
|
||||
CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs) const
|
||||
{
|
||||
const auto [i_m, i_n] = TilePartitioner{}(block_start, kargs.N);
|
||||
// options
|
||||
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
|
||||
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
|
||||
// Convert pointers to tensor views
|
||||
auto a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_start,
|
||||
make_tuple(kargs.M, kargs.K),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<GemmPipeline::VectorSizeA>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_start,
|
||||
make_tuple(kargs.M, kargs.K),
|
||||
make_tuple(1, kargs.stride_A),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
const auto [iM, iN] =
|
||||
OffsetTile1DPartitioner::GetOffsetedTileIndex(kargs.block_start, kargs.group_karg.N);
|
||||
|
||||
auto b_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_start,
|
||||
make_tuple(kargs.N, kargs.K),
|
||||
make_tuple(1, kargs.stride_B),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_start,
|
||||
make_tuple(kargs.N, kargs.K),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::VectorSizeB>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
auto a_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
// clang-format on
|
||||
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs.group_karg, blockIdx.z);
|
||||
|
||||
auto a_block_window = make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
|
||||
auto b_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadN, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto b_block_window = make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.group_karg.a_ptr);
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.group_karg.b_ptr);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.group_karg.c_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
|
||||
|
||||
// Run GEMM cooperatively by whole wokrgroup.
|
||||
auto c_block_tile =
|
||||
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
|
||||
|
||||
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
|
||||
auto c_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
c_start,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
number<GemmPipeline::VectorSizeC>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
c_start,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_C),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto c_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
auto CBlockWindow_pad = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
|
||||
this->RunGemm(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr, kargs.group_karg, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
int group_count) const
|
||||
index_t group_count) const
|
||||
{
|
||||
const index_t block_id = ck_tile::get_block_1d_id();
|
||||
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
|
||||
@@ -286,7 +167,7 @@ struct GroupedGemmKernel
|
||||
|
||||
index_t left = 0;
|
||||
index_t right = group_count;
|
||||
index_t group_id = index_t((left + right) / 2);
|
||||
index_t group_id = index_t((left + right) >> 1);
|
||||
|
||||
while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
|
||||
block_id < gemm_desc_ptr[group_id].block_end)) &&
|
||||
@@ -300,10 +181,10 @@ struct GroupedGemmKernel
|
||||
{
|
||||
left = group_id;
|
||||
}
|
||||
group_id = index_t((left + right) / 2);
|
||||
group_id = index_t((left + right) >> 1);
|
||||
}
|
||||
|
||||
Run(gemm_desc_ptr[group_id].group_karg, gemm_desc_ptr[group_id].block_start);
|
||||
Run(gemm_desc_ptr[group_id]);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user