mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
Tensile-style block to C tile map (#239)
* fix build
* Revert "fix build"
This reverts commit d73102384b.
* post PR #235 merge fix
* amend
* adds tensile-style c-tile map
* make it dynamic version
* add k-split flavor tile map
* apply tensile-style tile map to all xdl gridwise gemms
* remove dead code
Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -8,6 +8,237 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Rows of column-vectors
|
||||
template <index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
typename CGridDesc_M_N,
|
||||
bool DeviceCTileIndexCheck = false>
|
||||
struct BlockToCTileMap_M00_N0_M01
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
__host__ __device__ BlockToCTileMap_M00_N0_M01() = default;
|
||||
|
||||
__host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
index_t M01 = 1)
|
||||
: M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01))
|
||||
{
|
||||
}
|
||||
|
||||
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
|
||||
|
||||
const auto M00 = math::integer_divide_ceil(M0, M01_);
|
||||
|
||||
const index_t grid_size = M00 * M01_ * N0;
|
||||
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
|
||||
{
|
||||
return underlying_map_.CalculateBottomIndex(idx_top);
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
|
||||
const CTileDim& c_tile_dim) const
|
||||
{
|
||||
if constexpr(DeviceCTileIndexCheck)
|
||||
return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
|
||||
else
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
if constexpr(DeviceCTileIndexCheck)
|
||||
return true; // validity check moved to kernel
|
||||
|
||||
const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
|
||||
if(M0 % M01_ == 0)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
__host__ __device__ static constexpr auto
|
||||
GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01)
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
|
||||
|
||||
const auto M00 = math::integer_divide_ceil(M0, M01);
|
||||
|
||||
const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_insert_transform(1),
|
||||
make_unmerge_transform(make_tuple(M00, M01)),
|
||||
make_pass_through_transform(make_tuple(N0))),
|
||||
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
|
||||
|
||||
const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(1, M00, N0, M01))),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto cblockid_to_m0_n0_block_cluster_adaptor =
|
||||
chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor,
|
||||
cblockid_to_m00_n0_m01_block_cluster_adaptor);
|
||||
|
||||
return cblockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
|
||||
index_t M01_;
|
||||
using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1));
|
||||
UnderlyingMap underlying_map_;
|
||||
};
|
||||
|
||||
// Rows of column-vectors
|
||||
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
|
||||
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
|
||||
struct BlockToCTileMap_M00_N0_M01Adapt
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;
|
||||
|
||||
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
index_t M01 = 8)
|
||||
: M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
|
||||
|
||||
const index_t grid_size = M0 * N0;
|
||||
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
|
||||
{
|
||||
auto block_1d_id = idx_top[I0];
|
||||
|
||||
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
|
||||
|
||||
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
|
||||
|
||||
index_t idx_N0 = block_1d_id % N0;
|
||||
index_t idx_M0 = block_1d_id / N0;
|
||||
|
||||
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
|
||||
|
||||
index_t idx_M00 = idx_M0 / M01_;
|
||||
index_t idx_M01 = idx_M0 % M01_;
|
||||
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
|
||||
|
||||
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
|
||||
idx_N0_M01_local / M01_adapt);
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
|
||||
const CTileDim& /* c_tile_dim */) const
|
||||
{
|
||||
return true; // always valid provided that user gets grid size from CalculateGridSize()
|
||||
}
|
||||
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
|
||||
|
||||
private:
|
||||
index_t M01_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
};
|
||||
|
||||
// 2D slices of column-vectors in 3D space
|
||||
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
|
||||
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
|
||||
struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
__host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt() = default;
|
||||
|
||||
__host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
index_t M01 = 8,
|
||||
index_t KSplit = 1)
|
||||
: M01_(M01), KSplit_(KSplit), c_grid_desc_m_n_(c_grid_desc_m_n)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
|
||||
|
||||
const index_t grid_size = M0 * N0 * KSplit_;
|
||||
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
|
||||
{
|
||||
auto block_1d_id = idx_top[I0];
|
||||
|
||||
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
|
||||
|
||||
const index_t idx_ksplit = block_1d_id / (M0 * N0);
|
||||
block_1d_id = block_1d_id % (M0 * N0);
|
||||
|
||||
index_t idx_N0 = block_1d_id % N0;
|
||||
index_t idx_M0 = block_1d_id / N0;
|
||||
|
||||
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
|
||||
|
||||
index_t idx_M00 = idx_M0 / M01_;
|
||||
index_t idx_M01 = idx_M0 % M01_;
|
||||
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
|
||||
|
||||
return make_tuple(idx_ksplit,
|
||||
idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
|
||||
idx_N0_M01_local / M01_adapt);
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
|
||||
const CTileDim& /* c_tile_dim */) const
|
||||
{
|
||||
return true; // always valid provided that user gets grid size from CalculateGridSize()
|
||||
}
|
||||
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
|
||||
|
||||
private:
|
||||
index_t M01_;
|
||||
index_t KSplit_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
};
|
||||
|
||||
// Blocks of row-vectors
|
||||
template <index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
|
||||
@@ -306,7 +306,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
|
||||
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
|
||||
@@ -259,7 +259,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
|
||||
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
|
||||
@@ -288,11 +288,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
|
||||
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
|
||||
const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
|
||||
{
|
||||
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
|
||||
c_grid_desc_m_n, M01, N01);
|
||||
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
|
||||
|
||||
@@ -265,10 +265,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
|
||||
const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
|
||||
const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
|
||||
{
|
||||
return BlockToCTileMap_KSplit_M00_N00_M01_N01<MPerBlock, NPerBlock, CMNGridDesc>(
|
||||
c_m_n_grid_desc, M01, N01, KBatch);
|
||||
return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CMNGridDesc>(
|
||||
c_m_n_grid_desc, 8, KBatch);
|
||||
}
|
||||
|
||||
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
|
||||
|
||||
@@ -239,10 +239,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
|
||||
const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
|
||||
const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
|
||||
{
|
||||
return BlockToCTileMap_KSplit_M00_N00_M01_N01<MPerBlock, NPerBlock, CMNGridDesc>(
|
||||
c_m_n_grid_desc, M01, N01, KBatch);
|
||||
return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CMNGridDesc>(
|
||||
c_m_n_grid_desc, 8, KBatch);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
|
||||
@@ -300,11 +300,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
|
||||
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
|
||||
const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
|
||||
{
|
||||
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
|
||||
c_grid_desc_m_n, M01, N01);
|
||||
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
|
||||
remove_cvref_t<decltype(
|
||||
|
||||
@@ -309,11 +309,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
|
||||
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
|
||||
const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
|
||||
{
|
||||
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
|
||||
c_grid_desc_m_n, M01, N01);
|
||||
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
|
||||
|
||||
@@ -316,11 +316,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
|
||||
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
|
||||
const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
|
||||
{
|
||||
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
|
||||
c_grid_desc_m_n, M01, N01);
|
||||
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
|
||||
remove_cvref_t<decltype(
|
||||
|
||||
Reference in New Issue
Block a user