mirror of https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00

[rocm-libraries] ROCm/rocm-libraries#4961 (commit 6c3969a)

[CK] Remove code duplications in grouped gemm fixed nk implementations (#4961)

## Motivation

Different flavours of the grouped gemm fixed nk implementations share the same block-to-tile mapping logic, yet the code responsible for it is duplicated in each device struct implementation.

- Move `BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops` and `OffsettedBlockToCTileMapMLoops` from the device struct implementations to a common header file.
- Use the generic kernel argument structures in the xdl versions of the fixed nk implementations.

## Test Plan

CI in general. The relevant tests and examples are all fixed_nk versions of grouped gemm multiple D and ABD.

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
committed by assistant-librarian[bot]
parent e6d1781f20
commit d4236de1ba
@@ -0,0 +1,167 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck/utility/common_header.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+struct DeviceGroupedGemm_Fixed_NK_Common
+{
+    template <typename UnderlyingBlockToCTileMap, bool HasSplitKSupport = true>
+    struct OffsettedBlockToCTileMapMLoops
+    {
+        using underlying_type = UnderlyingBlockToCTileMap;
+
+        __host__ __device__ OffsettedBlockToCTileMapMLoops(
+            UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
+        {
+            block_to_ctile_map_ = block_to_ctile_map;
+            block_start_ = block_start;
+            id_off_ = id_off;
+        }
+
+        template <typename TopIdx>
+        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+        {
+            auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
+                make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_));
+
+            // Works around the fact that gridwise gemm implementations without splitk support
+            // require a different index mapping.
+            if constexpr(HasSplitKSupport)
+            {
+                return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]);
+            }
+            else
+            {
+                return make_tuple(idx_bot[Number<1>{}], idx_bot[Number<2>{}]);
+            }
+        }
+
+        template <typename CTileIdx, typename CTileDim>
+        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
+                                                 const CTileDim& c_tile_dim) const
+        {
+            return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
+        }
+
+        template <typename CGridDesc_M_N>
+        __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+        {
+            return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
+        }
+
+        template <typename CGridDesc_M_N>
+        __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
+        {
+            return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
+        }
+
+        UnderlyingBlockToCTileMap block_to_ctile_map_;
+        index_t block_start_;
+        index_t id_off_;
+    };
+
+    template <index_t MPerBlock, index_t NPerBlock>
+    struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
+    {
+        static constexpr auto I0 = Number<0>{};
+        static constexpr auto I1 = Number<1>{};
+
+        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;
+
+        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
+            const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
+        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
+            BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
+        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
+        operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
+        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
+        operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
+
+        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M,
+                                                                          index_t N,
+                                                                          index_t KBatch,
+                                                                          index_t M01 = 8)
+            : M_(M), N_(N), KBatch_(KBatch), M01_(M01)
+        {
+        }
+
+        template <typename CGridDesc_M_N>
+        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
+            const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
+            : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
+                  c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
+        {
+        }
+
+        __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
+        {
+            const auto M0 = math::integer_divide_ceil(M, MPerBlock);
+            const auto N0 = math::integer_divide_ceil(N, NPerBlock);
+
+            return M0 * N0 * KBatch_;
+        }
+
+        template <typename CGridDesc_M_N>
+        __host__ __device__ constexpr index_t
+        CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
+        {
+            return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
+        }
+
+        template <typename CGridDesc_M_N>
+        __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
+        {
+            return true;
+        }
+
+        template <typename TopIdx>
+        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+        {
+            auto block_1d_id = idx_top[I0];
+
+            const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
+            const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
+
+            block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
+
+            const index_t idx_ksplit = block_1d_id / (M0 * N0);
+            block_1d_id = block_1d_id % (M0 * N0);
+
+            index_t idx_N0 = block_1d_id % N0;
+            index_t idx_M0 = block_1d_id / N0;
+
+            const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
+
+            index_t idx_M00 = idx_M0 / M01_;
+            index_t idx_M01 = idx_M0 % M01_;
+            index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
+
+            return make_tuple(idx_ksplit,
+                              idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
+                              idx_N0_M01_local / M01_adapt);
+        }
+
+        template <typename CTileIdx, typename CTileDim>
+        __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
+                                                 const CTileDim& /* c_tile_dim */) const
+        {
+            return true; // always valid provided that user gets grid size from CalculateGridSize()
+        }
+
+        private:
+        index_t M_;
+        index_t N_;
+        index_t KBatch_;
+        index_t M01_;
+    };
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
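For illustration only (not part of the commit): a minimal standalone sketch of the index arithmetic that `BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops::CalculateBottomIndex` performs, with plain `int` standing in for `ck::index_t` and the CK tuple machinery; the tile sizes and problem shape are made-up values.

// Standalone sketch (assumed values); mirrors the map above without CK types.
#include <cstdio>

constexpr int MPerBlock = 128;
constexpr int NPerBlock = 128;

// {ksplit, m0, n0}: the K-split slice and the swizzled M/N tile for a 1-D block id.
struct TileIdx { int ksplit, m0, n0; };

TileIdx calculate_bottom_index(int block_1d_id, int M, int N, int KBatch, int M01 = 8)
{
    const int M0 = (M + MPerBlock - 1) / MPerBlock; // integer_divide_ceil
    const int N0 = (N + NPerBlock - 1) / NPerBlock;

    block_1d_id %= M0 * N0 * KBatch; // "hide groups": wrap ids belonging to later groups

    const int idx_ksplit = block_1d_id / (M0 * N0);
    block_1d_id %= M0 * N0;

    const int idx_N0 = block_1d_id % N0;
    const int idx_M0 = block_1d_id / N0;

    // The last partial band of M tiles uses a narrower swizzle window.
    const int M01_adapt = (idx_M0 < M0 - M0 % M01) ? M01 : M0 % M01;

    const int idx_M00 = idx_M0 / M01;
    const int idx_M01 = idx_M0 % M01;
    const int local   = idx_N0 + idx_M01 * N0;

    return {idx_ksplit, local % M01_adapt + idx_M00 * M01, local / M01_adapt};
}

int main()
{
    // 512x384 problem with KBatch=2 -> ceil(512/128) * ceil(384/128) * 2 = 24 blocks per group.
    for(int id = 0; id < 24; ++id)
    {
        const TileIdx t = calculate_bottom_index(id, 512, 384, 2);
        std::printf("block %2d -> (ksplit=%d, m0=%d, n0=%d)\n", id, t.ksplit, t.m0, t.n0);
    }
}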
@@ -21,6 +21,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -302,149 +303,11 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK
                 false,
                 false>;
 
-    // TODO: Block to tile mappings could potentially moved out to avoid code duplications between
-    // different device implementations.
-
-    template <typename UnderlyingBlockToCTileMap>
-    struct OffsettedBlockToCTileMapMLoops
-    {
-        using underlying_type = UnderlyingBlockToCTileMap;
-
-        __host__ __device__ OffsettedBlockToCTileMapMLoops(
-            UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
-        {
-            block_to_ctile_map_ = block_to_ctile_map;
-            block_start_ = block_start;
-            id_off_ = id_off;
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
-                make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_));
-
-            return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]);
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
-                                                 const CTileDim& c_tile_dim) const
-        {
-            return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
-        }
-
-        UnderlyingBlockToCTileMap block_to_ctile_map_;
-        index_t block_start_;
-        index_t id_off_;
-    };
-
-    template <index_t MPerBlock_, index_t NPerBlock_>
-    struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
-    {
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
-        operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
-        operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M,
-                                                                          index_t N,
-                                                                          index_t KBatch,
-                                                                          index_t M01 = 8)
-            : M_(M), N_(N), KBatch_(KBatch), M01_(M01)
-        {
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
-            : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-                  c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
-        {
-        }
-
-        __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
-        {
-            const auto M0 = math::integer_divide_ceil(M, MPerBlock);
-            const auto N0 = math::integer_divide_ceil(N, NPerBlock);
-
-            return M0 * N0 * KBatch_;
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ __device__ constexpr index_t
-        CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
-        {
-            return true;
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            auto block_1d_id = idx_top[I0];
-
-            const auto M0 = math::integer_divide_ceil(M_, MPerBlock_);
-            const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
-
-            block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
-
-            const index_t idx_ksplit = block_1d_id / (M0 * N0);
-            block_1d_id = block_1d_id % (M0 * N0);
-
-            index_t idx_N0 = block_1d_id % N0;
-            index_t idx_M0 = block_1d_id / N0;
-
-            const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
-
-            index_t idx_M00 = idx_M0 / M01_;
-            index_t idx_M01 = idx_M0 % M01_;
-            index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
-
-            return make_tuple(idx_ksplit,
-                              idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
-                              idx_N0_M01_local / M01_adapt);
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
-                                                 const CTileDim& /* c_tile_dim */) const
-        {
-            return true; // always valid provided that user gets grid size from CalculateGridSize()
-        }
-
-        private:
-        index_t M_;
-        index_t N_;
-        index_t KBatch_;
-        index_t M01_;
-    };
-
-    using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
-    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
+    using Block2ETileMap =
+        DeviceGroupedGemm_Fixed_NK_Common::BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock,
+                                                                                         NPerBlock>;
+    using GroupedGemmBlock2ETileMap =
+        DeviceGroupedGemm_Fixed_NK_Common::OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
 
     static constexpr index_t DefaultKBatch = 1; // implementation only supports KBatch == 1
     using KernelArgument = typename GridwiseGemm::Argument;
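Also for illustration (names invented for the sketch): how an offsetted map of this shape is composed per group in a grouped gemm — each group owns a contiguous range of 1-D block ids starting at `block_start`, and the wrapper rebases the global id into group-local space before delegating to the underlying map, just as `OffsettedBlockToCTileMapMLoops` does with `- block_start_ + id_off_`.

#include <cstdio>
#include <vector>

// Invented stand-in for the underlying map: tile count per group, MLoops-style wrap.
struct UnderlyingMap
{
    int tiles; // number of tiles this group needs
    int CalculateBottomIndex(int id) const { return id % tiles; }
};

// Mirrors OffsettedBlockToCTileMapMLoops: rebase a global block id into the group.
struct OffsettedMap
{
    UnderlyingMap map;
    int block_start; // first global block id owned by this group
    int id_off;      // extra offset for M-loop iterations (0 here)

    int CalculateBottomIndex(int global_id) const
    {
        return map.CalculateBottomIndex(global_id - block_start + id_off);
    }
};

int main()
{
    // Three groups needing 24, 12 and 6 tiles; block_start is the running prefix sum.
    std::vector<UnderlyingMap> groups{{24}, {12}, {6}};
    std::vector<OffsettedMap> offsetted;
    int start = 0;
    for(const auto& g : groups)
    {
        offsetted.push_back({g, start, 0});
        start += g.tiles;
    }

    // Global block 30 falls into group 1 (ids 24..35) and maps to local tile 6.
    std::printf("%d\n", offsetted[1].CalculateBottomIndex(30));
}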
@@ -12,6 +12,7 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -268,167 +269,14 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
             LoopSched>;
     using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
     using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
-    template <typename UnderlyingBlockToCTileMap>
-    struct OffsettedBlockToCTileMapMLoops
-    {
-        using underlying_type = UnderlyingBlockToCTileMap;
-
-        __host__ __device__ OffsettedBlockToCTileMapMLoops(
-            UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
-        {
-            block_to_ctile_map_ = block_to_ctile_map;
-            block_start_ = block_start;
-            id_off_ = id_off;
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
-                make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_));
-
-            return make_tuple(
-                // idx_bot[Number<0>{}],
-                idx_bot[Number<1>{}],
-                idx_bot[Number<2>{}]);
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
-                                                 const CTileDim& c_tile_dim) const
-        {
-            return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
-        }
-
-        UnderlyingBlockToCTileMap block_to_ctile_map_;
-        index_t block_start_;
-        index_t id_off_;
-    };
-
-    template <index_t MPerBlock_, index_t NPerBlock_>
-    struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
-    {
-        static constexpr auto I0 = Number<0>{};
-        static constexpr auto I1 = Number<1>{};
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
-        operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
-        operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M,
-                                                                          index_t N,
-                                                                          index_t KBatch,
-                                                                          index_t M01 = 8)
-            : M_(M), N_(N), KBatch_(KBatch), M01_(M01)
-        {
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
-            : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-                  c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
-        {
-        }
-
-        __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
-        {
-            const auto M0 = math::integer_divide_ceil(M, MPerBlock);
-            const auto N0 = math::integer_divide_ceil(N, NPerBlock);
-
-            return M0 * N0 * KBatch_;
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ __device__ constexpr index_t
-        CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
-        {
-            return true;
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            auto block_1d_id = idx_top[I0];
-
-            const auto M0 = math::integer_divide_ceil(M_, MPerBlock_);
-            const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
-
-            block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
-
-            const index_t idx_ksplit = block_1d_id / (M0 * N0);
-            block_1d_id = block_1d_id % (M0 * N0);
-
-            index_t idx_N0 = block_1d_id % N0;
-            index_t idx_M0 = block_1d_id / N0;
-
-            const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
-
-            index_t idx_M00 = idx_M0 / M01_;
-            index_t idx_M01 = idx_M0 % M01_;
-            index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
-
-            return make_tuple(idx_ksplit,
-                              idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
-                              idx_N0_M01_local / M01_adapt);
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
-                                                 const CTileDim& /* c_tile_dim */) const
-        {
-            return true; // always valid provided that user gets grid size from CalculateGridSize()
-        }
-
-        private:
-        index_t M_;
-        index_t N_;
-        index_t KBatch_;
-        index_t M01_;
-    };
-
-    using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
-    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
-
-    struct GemmBiasTransKernelArg
-    {
-        // pointers
-        std::array<const void*, NumATensor> as_ptr_;
-        std::array<const void*, NumBTensor> bs_ptr_;
-        std::array<const void*, NumDTensor> ds_ptr_;
-        void* e_ptr_;
-
-        index_t M_, N_, K_;
-        std::array<index_t, NumATensor> StrideAs_;
-        std::array<index_t, NumBTensor> StrideBs_;
-        std::array<index_t, NumDTensor> StrideDs_;
-        index_t StrideE_;
-    };
+    using Block2ETileMap =
+        DeviceGroupedGemm_Fixed_NK_Common::BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock,
+                                                                                         NPerBlock>;
+    using GroupedGemmBlock2ETileMap =
+        DeviceGroupedGemm_Fixed_NK_Common::OffsettedBlockToCTileMapMLoops<Block2ETileMap, false>;
+
+    using KernelArgument = GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>;
 
     // Argument
     struct Argument : public BaseArgument
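A side note on the `<Block2ETileMap, false>` alias above: this gridwise gemm has no split-K support and expects a 2-D tile index, so the common wrapper's `HasSplitKSupport` flag selects the tuple arity at compile time instead of the old commented-out `idx_bot[Number<0>{}]` workaround. A minimal standalone illustration of that `if constexpr` dispatch (types invented for the sketch):

#include <cstdio>
#include <tuple>

// Invented mini-version of the compile-time arity switch in OffsettedBlockToCTileMapMLoops.
template <bool HasSplitKSupport>
auto bottom_index(int ksplit, int m0, int n0)
{
    if constexpr(HasSplitKSupport)
    {
        return std::make_tuple(ksplit, m0, n0); // (ksplit, m, n) for split-K kernels
    }
    else
    {
        return std::make_tuple(m0, n0); // (m, n) when the gridwise gemm has no split-K
    }
}

int main()
{
    auto with    = bottom_index<true>(1, 2, 3);  // 3-tuple
    auto without = bottom_index<false>(1, 2, 3); // 2-tuple, ksplit dropped
    std::printf("%zu %zu\n",
                std::tuple_size<decltype(with)>::value,
                std::tuple_size<decltype(without)>::value);
}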
@@ -537,7 +385,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
                     throw std::runtime_error("wrong! block_2_etile_map validation failed");
                 }
 
-                gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{
+                gemm_desc_kernel_arg_.push_back(KernelArgument{
                     p_as_grid,
                     p_bs_grid,
                     p_ds_grid,
@@ -556,7 +404,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
 
             const auto e_grid_desc_sum_m_n =
                 GridwiseGemm64::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
-                    sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_);
+                    sum_of_m, gemm_desc_kernel_arg_[0].N, gemm_desc_kernel_arg_[0].StrideE);
 
             const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1};
 
@@ -570,7 +418,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
         BElementwiseOperation b_element_op_;
         CDEElementwiseOperation c_element_op_;
 
-        std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
+        std::vector<KernelArgument> gemm_desc_kernel_arg_;
         std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
         std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
@@ -596,7 +444,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
 
         for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
         {
-            if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) !=
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K) !=
               has_main_k_block_loop)
            {
                throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
@@ -729,7 +577,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
         {
             if(get_warp_size() == 64)
             {
-                if(GridwiseGemm64::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) !=
+                if(GridwiseGemm64::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K) !=
                   true)
                {
                    supported = false;
@@ -737,7 +585,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
             }
             else
             {
-                if(GridwiseGemm32::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) !=
+                if(GridwiseGemm32::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K) !=
                   true)
                {
                    supported = false;
@@ -20,6 +20,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
 
 namespace ck {
@@ -328,152 +329,11 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK<ALayout
         remove_cvref_t<decltype(GridwiseGemm::template MakeDEGridDescriptor_M_N<ELayout>(
             1, 1, 1, 1, 1))>;
 
-    template <typename UnderlyingBlockToCTileMap>
-    struct OffsettedBlockToCTileMapMLoops
-    {
-        using underlying_type = UnderlyingBlockToCTileMap;
-
-        __host__ __device__ OffsettedBlockToCTileMapMLoops(
-            UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
-        {
-            block_to_ctile_map_ = block_to_ctile_map;
-            block_start_ = block_start;
-            id_off_ = id_off;
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
-                make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_));
-
-            return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]);
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
-                                                 const CTileDim& c_tile_dim) const
-        {
-            return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
-        }
-
-        UnderlyingBlockToCTileMap block_to_ctile_map_;
-        index_t block_start_;
-        index_t id_off_;
-    };
-
-    template <index_t MPerBlock_, index_t NPerBlock_>
-    struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
-    {
-        static constexpr auto I0 = Number<0>{};
-        static constexpr auto I1 = Number<1>{};
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
-        operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
-        operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M,
-                                                                          index_t N,
-                                                                          index_t KBatch,
-                                                                          index_t M01 = 8)
-            : M_(M), N_(N), KBatch_(KBatch), M01_(M01)
-        {
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
-            : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-                  c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
-        {
-        }
-
-        __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
-        {
-            const auto M0 = math::integer_divide_ceil(M, MPerBlock);
-            const auto N0 = math::integer_divide_ceil(N, NPerBlock);
-
-            return M0 * N0 * KBatch_;
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ __device__ constexpr index_t
-        CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ bool CheckValidity(const CGridDesc_M_N&) const
-        {
-            return true;
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            auto block_1d_id = idx_top[I0];
-
-            const auto M0 = math::integer_divide_ceil(M_, MPerBlock_);
-            const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
-
-            const auto total_tiles_per_group = M0 * N0 * KBatch_;
-
-            // wrap block id into this group
-            block_1d_id = block_1d_id % total_tiles_per_group;
-
-            const index_t idx_ksplit = block_1d_id / (M0 * N0);
-            block_1d_id = block_1d_id % (M0 * N0);
-
-            index_t idx_N0 = block_1d_id % N0;
-            index_t idx_M0 = block_1d_id / N0;
-
-            const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
-
-            index_t idx_M00 = idx_M0 / M01_;
-            index_t idx_M01 = idx_M0 % M01_;
-            index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
-
-            return make_tuple(idx_ksplit,
-                              idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
-                              idx_N0_M01_local / M01_adapt);
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
-                                                 const CTileDim& /* c_tile_dim */) const
-        {
-            return true; // always valid provided that user gets grid size from CalculateGridSize()
-        }
-
-        private:
-        index_t M_;
-        index_t N_;
-        index_t KBatch_;
-        index_t M01_;
-    };
-
-    using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
-    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
+    using Block2ETileMap =
+        DeviceGroupedGemm_Fixed_NK_Common::BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock,
+                                                                                         NPerBlock>;
+    using GroupedGemmBlock2ETileMap =
+        DeviceGroupedGemm_Fixed_NK_Common::OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
 
     static constexpr index_t DefaultKBatch = 1;
     using KernelArgument = typename GridwiseGemm::Argument;
@@ -12,6 +12,7 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -309,164 +310,13 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
     using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
     using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
 
-    template <typename UnderlyingBlockToCTileMap>
-    struct OffsettedBlockToCTileMapMLoops
-    {
-        using underlying_type = UnderlyingBlockToCTileMap;
-
-        __host__ __device__ OffsettedBlockToCTileMapMLoops(
-            UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
-        {
-            block_to_ctile_map_ = block_to_ctile_map;
-            block_start_ = block_start;
-            id_off_ = id_off;
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
-                make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_));
-
-            return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]);
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
-                                                 const CTileDim& c_tile_dim) const
-        {
-            return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
-        }
-
-        UnderlyingBlockToCTileMap block_to_ctile_map_;
-        index_t block_start_;
-        index_t id_off_;
-    };
-
-    template <index_t MPerBlock_, index_t NPerBlock_>
-    struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
-    {
-        static constexpr auto I0 = Number<0>{};
-        static constexpr auto I1 = Number<1>{};
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
-        operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
-        operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M,
-                                                                          index_t N,
-                                                                          index_t KBatch,
-                                                                          index_t M01 = 8)
-            : M_(M), N_(N), KBatch_(KBatch), M01_(M01)
-        {
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
-            : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-                  c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
-        {
-        }
-
-        __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
-        {
-            const auto M0 = math::integer_divide_ceil(M, MPerBlock);
-            const auto N0 = math::integer_divide_ceil(N, NPerBlock);
-
-            return M0 * N0 * KBatch_;
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ __device__ constexpr index_t
-        CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
-        {
-            return true;
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            auto block_1d_id = idx_top[I0];
-
-            const auto M0 = math::integer_divide_ceil(M_, MPerBlock_);
-            const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
-
-            block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
-
-            const index_t idx_ksplit = block_1d_id / (M0 * N0);
-            block_1d_id = block_1d_id % (M0 * N0);
-
-            index_t idx_N0 = block_1d_id % N0;
-            index_t idx_M0 = block_1d_id / N0;
-
-            const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
-
-            index_t idx_M00 = idx_M0 / M01_;
-            index_t idx_M01 = idx_M0 % M01_;
-            index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
-
-            return make_tuple(idx_ksplit,
-                              idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
-                              idx_N0_M01_local / M01_adapt);
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
-                                                 const CTileDim& /* c_tile_dim */) const
-        {
-            return true; // always valid provided that user gets grid size from CalculateGridSize()
-        }
-
-        private:
-        index_t M_;
-        index_t N_;
-        index_t KBatch_;
-        index_t M01_;
-    };
-
-    using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
-    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
-
-    // TODO: replace with GroupedGemmKernelArgument
-    struct GemmBiasTransKernelArg
-    {
-        // pointers
-        const void* a_ptr_;
-        const void* b_ptr_;
-        std::array<const void*, NumDTensor> ds_ptr_;
-        void* e_ptr_;
-
-        index_t M_, N_, K_;
-        index_t StrideA_, StrideB_;
-        std::array<index_t, NumDTensor> StrideDs_;
-        index_t StrideE_;
-    };
+    using Block2ETileMap =
+        DeviceGroupedGemm_Fixed_NK_Common::BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock,
+                                                                                         NPerBlock>;
+    using GroupedGemmBlock2ETileMap =
+        DeviceGroupedGemm_Fixed_NK_Common::OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
+
+    using KernelArgument = GroupedGemmKernelArgument<NumDTensor>;
 
     // Argument
     struct Argument : public BaseArgument
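For orientation (a guess at the shape, inferred only from the member renames visible in the surrounding hunks, where `.N_`/`.K_`/`.StrideE_` become `.N`/`.K`/`.StrideE`; the real `GroupedGemmKernelArgument` is defined elsewhere in CK): the generic argument is a plain aggregate comparable to the removed `GemmBiasTransKernelArg`.

#include <array>
#include <cstdint>

using index_t = std::int32_t; // stand-in for ck::index_t

// Illustrative aggregate only; not the actual CK definition.
template <index_t NumDTensor>
struct GroupedGemmKernelArgumentSketch
{
    const void* a_ptr;
    const void* b_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    void* e_ptr;

    index_t M, N, K; // no trailing underscores, hence the .N_ -> .N renames below
    index_t StrideA, StrideB;
    std::array<index_t, NumDTensor> StrideDs;
    index_t StrideE;
};

int main()
{
    GroupedGemmKernelArgumentSketch<2> arg{}; // two D tensors, zero-initialized
    (void)arg;
}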
@@ -484,8 +334,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
 
         const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_);
 
-        const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE_;
-        const index_t N = gemm_desc_kernel_arg_[0].N_;
+        const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE;
+        const index_t N = gemm_desc_kernel_arg_[0].N;
 
         const auto e_grid_desc_m_n =
             GridwiseGemm64::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
@@ -626,7 +476,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                 }
             }
 
-            gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{
+            gemm_desc_kernel_arg_.push_back(KernelArgument{
                 nullptr,
                 nullptr,
                 p_ds_grid,
@@ -645,7 +495,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
 
         const auto e_grid_desc_sum_m_n =
             GridwiseGemm64::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
-                sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_);
+                sum_of_m, gemm_desc_kernel_arg_[0].N, gemm_desc_kernel_arg_[0].StrideE);
 
         const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1};
 
@@ -659,7 +509,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         BElementwiseOperation b_element_op_;
         CDEElementwiseOperation c_element_op_;
 
-        std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
+        std::vector<KernelArgument> gemm_desc_kernel_arg_;
         std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
         std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
@@ -686,7 +536,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
         {
             const auto KPad =
-                GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, arg.k_batch_);
+                GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K, arg.k_batch_);
 
             if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop)
             {