Grouped GEMM Multiple D tile loop. (#1247)

* Overload output stream operator for LoopScheduler and PipelineVersion

* Add a Run overload accepting M/K grid descriptors.

* Add the __device__ qualifier to CalculateGridSize.

* Create device op GroupedGemmMultipleD

* Add GroupedGemm MultipleD Tile Loop implementation.

* Add an example for GroupedGemm MultipleD tile loop.

* Device Op GroupedGEMMTileLoop.

* A bunch of small changes in the example.

* ckProfiler

* Remove unused tparam.

* Fix include statement.

* Fix output stream overloads.

* Do not create descriptors or check validity until we find the group.

* Fix gemm desc initialization.

* Revert device op

* Fix compilation for DTYPES=FP16

* Validate tensor transfer parameters.

* On host, validate only the N and K dims if M is not known.

* Fix bug.

* Add a convenient debug function for selecting threads (see the sketch after this list).

* Fix HasMainKBlockLoop bug.

* Make sure that the block-to-C-tile map has an up-to-date tile offset.

* Output stream operator for Sequence type.

* CMake file formatting.
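The thread-selection debug helper mentioned above does not appear in the hunks shown below. A minimal sketch of such a predicate, with hypothetical names (not the exact helper from this commit), gates device-side printf to a single thread so per-tile state can be inspected without flooding the output:

// Hypothetical debug predicate (names assumed); true only on the selected
// thread of the selected workgroup.
__device__ inline bool debug_thread(int tid = 0, int bid = 0)
{
    return static_cast<int>(threadIdx.x) == tid && static_cast<int>(blockIdx.x) == bid;
}

// Usage inside a kernel:
// if(debug_thread()) { printf("b2c tile offset: %d\n", tile_offset); }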
This commit is contained in:
Adam Osewski
2024-04-25 22:12:53 +02:00
committed by GitHub
parent f448d179b7
commit b4032629e5
20 changed files with 2264 additions and 22 deletions
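For reference, grouped GEMM with multiple D sources computes, for every group g with its own problem size, E_g = cde_op(A_g x B_g, D0_g, ...). A plain host-side sketch of these semantics (a reference loop under assumed row-major layouts, not the CK device-op API):

#include <vector>

// One GEMM problem in the group; each group may have different M, N, K.
struct GemmProblem
{
    int M, N, K;
    const float* a;  // M x K, row-major
    const float* b;  // K x N, row-major
    const float* d0; // M x N, extra source operand
    float* e;        // M x N, output
};

// Reference semantics: E = A * B combined elementwise with D0, per group.
void grouped_gemm_multiple_d_reference(const std::vector<GemmProblem>& groups)
{
    for(const auto& g : groups)
        for(int m = 0; m < g.M; ++m)
            for(int n = 0; n < g.N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < g.K; ++k)
                    acc += g.a[m * g.K + k] * g.b[k * g.N + n];
                // example CDE elementwise op: plain addition of the D tensor
                g.e[m * g.N + n] = acc + g.d0[m * g.N + n];
            }
}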

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -151,7 +151,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{
}
-__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
+__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
@@ -275,7 +275,7 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
{
}
-__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
+__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
@@ -428,7 +428,7 @@ struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
{
}
-__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
+__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
@@ -900,6 +900,11 @@ struct OffsettedBlockToCTileMap
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
}
+__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
+{
+return block_to_ctile_map_.CalculateGridSize(M, N);
+}
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t block_start_;
};
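Adding __device__ to CalculateGridSize(M, N) lets the tile-loop kernel size each group's tile grid on the GPU, where the per-group M may first become known. The computation itself is one ceiling division per dimension; a standalone sketch (integer_divide_ceil mirrors ck::math::integer_divide_ceil):

#include <cstdint>

// Ceiling division, as in ck::math::integer_divide_ceil.
constexpr std::int32_t integer_divide_ceil(std::int32_t x, std::int32_t y)
{
    return (x + y - 1) / y;
}

// One workgroup per MPerBlock x NPerBlock tile of the output matrix.
constexpr std::int32_t calculate_grid_size(std::int32_t M, std::int32_t N,
                                           std::int32_t MPerBlock, std::int32_t NPerBlock)
{
    return integer_divide_ceil(M, MPerBlock) * integer_divide_ceil(N, NPerBlock);
}

static_assert(calculate_grid_size(1000, 512, 256, 128) == 16, "4 x 4 tiles");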

View File

@@ -257,7 +257,70 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
e_grid_desc_m_n);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+template <typename ALayout, typename BLayout, typename ELayout>
+__host__ __device__ static bool
+CheckTensorTransfersValidity(index_t MRaw, index_t NRaw, index_t KRaw)
+{
+// Check if the vector dim is K1 or M|N
+const auto A_vector_dim_size = ABlockTransferSrcVectorDim == 2 ? KRaw : MRaw;
+const auto B_vector_dim_size = BBlockTransferSrcVectorDim == 2 ? KRaw : NRaw;
+const auto E_vector_dim_size = NRaw;
+// check vector load for A tensor
+if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+{
+if(!(A_vector_dim_size == KRaw &&
+A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0))
+return false;
+}
+else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+{
+if(!(A_vector_dim_size == MRaw &&
+A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0))
+return false;
+}
+else
+{
+return false;
+}
+if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
+{
+if(!(B_vector_dim_size == NRaw &&
+B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0))
+return false;
+}
+else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
+{
+if(!(B_vector_dim_size == KRaw &&
+B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0))
+return false;
+}
+else
+{
+return false;
+}
+if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ELayout>)
+{
+if(!(E_vector_dim_size == NRaw &&
+E_vector_dim_size % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0))
+return false;
+}
+else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ELayout>)
+{
+if(!(E_vector_dim_size == NRaw &&
+CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1))
+return false;
+}
+else
+{
+return false;
+}
+return true;
+}
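Condensed, the rule CheckTensorTransfersValidity enforces is: each transfer's vectorized dimension must coincide with the tensor's contiguous dimension, and that dimension's raw extent must be divisible by the transfer's scalars-per-vector (a column-major E additionally requires a vector width of 1). A host-side sketch of the A-tensor case, with hypothetical names:

// Sketch (hypothetical names) of the A-tensor check above: vector loads must
// run along the contiguous dimension (K for row-major A, M for column-major A),
// and that extent must divide evenly by the vector width.
bool a_vector_load_valid(bool a_row_major, int vector_dim_size,
                         int MRaw, int KRaw, int scalar_per_vector)
{
    const int contiguous_dim = a_row_major ? KRaw : MRaw;
    return vector_dim_size == contiguous_dim &&
           vector_dim_size % scalar_per_vector == 0;
}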
template <typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
@@ -267,7 +330,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
-const Block2ETileMap&)
+[[maybe_unused]] const Block2ETileMap&)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -285,7 +348,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
{
return false;
}
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
@@ -306,7 +368,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// check gridwise gemm pipeline
const auto num_k_loop = AK / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
@@ -938,6 +999,63 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
+template <bool HasMainKBlockLoop,
+typename AGridDesc_MK,
+typename BGridDesc_NK,
+typename DsGridDesc_MN,
+typename EGridDesc_MN,
+typename Block2ETileMap>
+__device__ static void Run(const void* __restrict__ p_a_grid_,
+const void* __restrict__ p_b_grid_,
+DsGridPointer p_ds_grid,
+void* __restrict__ p_e_grid_,
+void* __restrict__ p_shared,
+const AElementwiseOperation& a_element_op,
+const BElementwiseOperation& b_element_op,
+const CDEElementwiseOperation& cde_element_op,
+const AGridDesc_MK& a_grid_desc_m_k,
+const BGridDesc_NK& b_grid_desc_n_k,
+const DsGridDesc_MN& ds_grid_desc_m_n,
+const EGridDesc_MN& e_grid_desc_m_n,
+const Block2ETileMap& block_2_etile_map)
+{
+const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
+const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
+const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
+// tensor descriptors for block/thread-wise copy
+const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
+const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
+using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+DsGridDesc_MN{}))>;
+DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
+static_for<0, NumDTensor, 1>{}([&](auto j) {
+ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
+MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
+});
+const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
+MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
+Run<HasMainKBlockLoop>(p_a_grid,
+p_b_grid,
+p_ds_grid,
+p_e_grid,
+p_shared,
+a_element_op,
+b_element_op,
+cde_element_op,
+a_grid_desc_ak0_m_ak1,
+b_grid_desc_bk0_n_bk1,
+ds_grid_desc_mblock_mperblock_nblock_nperblock,
+e_grid_desc_mblock_mperblock_nblock_nperblock,
+block_2_etile_map);
+}
};
} // namespace ck
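The new Run overload lets the grouped tile-loop kernel pass plain M/K and N/K descriptors per group and build the blocked AK0/M/AK1 and BK0/N/BK1 forms on the device before delegating to the existing overload. The HasMainKBlockLoop flag it forwards is, in CK's gridwise GEMMs, typically derived from the K-loop trip count; a hedged sketch of that rule (helper name assumed):

// Sketch of the usual CK dispatch rule (helper name assumed): the kernel is
// compiled in two variants, and the one with the main K block loop is chosen
// whenever more than one K block iteration is needed.
bool calculate_has_main_k_block_loop(int K, int KPerBlock)
{
    const int num_loop = (K + KPerBlock - 1) / KPerBlock;
    return num_loop > 1;
}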

View File

@@ -1,9 +1,10 @@
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
+#include <ostream>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
@@ -57,3 +58,16 @@ constexpr auto GridwiseGemmPipeline_Selector()
}
} // namespace ck
+inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
+{
+switch(p)
+{
+case ck::PipelineVersion::v1: os << "PipelineVersion::v1"; break;
+case ck::PipelineVersion::v2: os << "PipelineVersion::v2"; break;
+case ck::PipelineVersion::v4: os << "PipelineVersion::v4"; break;
+case ck::PipelineVersion::weight_only: os << "PipelineVersion::weight_only"; break;
+default: os << "";
+}
+return os;
+}
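A minimal usage sketch for the overload above (the include path is an assumption):

#include <iostream>
// Assumed header path for the pipeline selector shown above.
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"

int main()
{
    std::cout << ck::PipelineVersion::v2 << '\n'; // prints "PipelineVersion::v2"
    return 0;
}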