Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-05 14:11:29 +00:00
Simplify kernel argument of device operator DeviceGemm_Xdl_CShuffle<> (#696)
* Remove M/N/KPad local variables
* Use M/N/KPad to name padded lengths
* Replace duplicated local variables with parameters
* Rename variables M/N/KRaw to M/N/K
* Move AK0/BK0 compute logic into GridwiseGemm
* Use macro to shorten code
* Move CalculateGridSize() logic into GridwiseGemm
* Add comment to credit the implementation source
* Reuse the existing implementation
* Remove no-longer-used data members
* Remove elementwise-op objects from interfaces
* Reserve kernel arg as whole object in interfaces
* Remove redundant data member
* Make 3rd type parameter optional
* Remove unnecessary type parameters
* Remove no-longer-used descriptor-creation methods
* Move kernel arg type definition into GridwiseGemm
* Add macro to switch between code sections
* Move argument field computing logic into device op side
* Make utility method 'static'
* Declare special methods
* Unify MakeArgument() usage
* Adapt the new GridwiseGemm interface
* Push down class 'GridwiseGemm::Argument' fields
* Remove no-longer-used methods
* Add unused parameters
* Force copying parameters in 'Embed' ctor
* Remove no-longer-used descriptors
* Fall back change on BaseArgument
* Remove macro 'INTEGER_DIVIDE_CEIL'
* Make variable naming more consistent
* Make sure methods are only invoked in the right place
* Remove trailing underscore in public attribute name
* Remove unnecessary methods
* Hide computing logic of derived attributes
* Make new 'Embed' ctor only available for device code
* Make sure 'Embed' type args are not references
* Move check for karg.K into CheckValidity()
* Remove more integer division logic from device code
* Undo changes on Embed
* Separate 'Problem' concept out from 'Argument'
* Share same name for kernel interfaces
* Reject unsupported argument

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>
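The shape of the refactor, independent of CK's template machinery: the many kernel parameters collapse into one trivially copyable object with a `Problem` part (raw sizes and strides plus host-precomputed derived lengths) and an `Argument` part adding the grid pointers. A minimal sketch with simplified stand-in types (not CK's actual definitions):

```cpp
#include <cstdint>

using index_t = std::int32_t;

// Stand-in for GridwiseGemm::Problem: raw sizes plus host-derived lengths.
struct Problem
{
    index_t M, N, K;
    index_t StrideA, StrideB, StrideC;
    index_t MPadded, NPadded, KPadded; // derived once on the host
};

// Stand-in for GridwiseGemm::Argument: the Problem plus the grid pointers,
// handed to the kernel as a single object instead of prebuilt descriptors
// and elementwise-op objects.
struct Argument : Problem
{
    const float* p_a_grid;
    const float* p_b_grid;
    float* p_c_grid;
};
```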
@@ -109,30 +109,37 @@ struct BlockToCTileMap_M00_N0_M01
 // Rows of column-vectors

 // This C-tile map dynamically adjusts M01 when C-tile index is out of range
-template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
-struct BlockToCTileMap_M00_N0_M01Adapt
+template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
+struct BlockToCTileMap_M00_N0_M01Adapt;
+
+template <index_t MPerBlock, index_t NPerBlock>
+struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};

     __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;

-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
-                                                        index_t M01 = 8)
-        : M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
+    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) =
+        default;
+    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) =
+        default;
+    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
+    operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
+    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
+    operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
+
+    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
+        : M_(M), N_(N), M01_(M01)
     {
     }

-    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
+    __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
     {
-        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
-        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
+        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
+        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

-        const index_t grid_size = M0 * N0;
-
-        return grid_size;
+        return M0 * N0;
     }

     template <typename TopIdx>
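For a concrete feel of the new static `CalculateGridSize`: with `MPerBlock = NPerBlock = 128` and a 1000 x 1000 C matrix, the ceil-divisions give `M0 = N0 = 8`, so the grid has 64 workgroups. A minimal host-side sketch, with a local stand-in for `math::integer_divide_ceil`:

```cpp
using index_t = int;

// local stand-in for ck::math::integer_divide_ceil
constexpr index_t integer_divide_ceil(index_t x, index_t y) { return (x + y - 1) / y; }

constexpr index_t CalculateGridSize(index_t M, index_t N,
                                    index_t MPerBlock, index_t NPerBlock)
{
    const index_t M0 = integer_divide_ceil(M, MPerBlock); // tiles along M
    const index_t N0 = integer_divide_ceil(N, NPerBlock); // tiles along N
    return M0 * N0;
}

static_assert(CalculateGridSize(1000, 1000, 128, 128) == 8 * 8);
```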
@@ -140,8 +147,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt
     {
         auto block_1d_id = idx_top[I0];

-        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
-        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
+        const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
+        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);

         block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
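The rest of `CalculateBottomIndex` is unchanged by this hunk; only the source of `M0`/`N0` moves from the stored descriptor to the stored `M_`/`N_`. As a rough standalone sketch of what an M01-grouped "column-vector" mapping of this kind computes (a simplification for illustration, not the verbatim CK code): consecutive block ids walk down columns of height `M01` before moving right, so tiles that share A rows are scheduled close together, and the last band shrinks when `M0 % M01 != 0` (the "adapt" part).

```cpp
#include <utility>

using index_t = int;

// Simplified sketch of an M01-grouped block-id -> (m0, n0) mapping.
std::pair<index_t, index_t>
block_to_tile(index_t block_1d_id, index_t M0, index_t N0, index_t M01)
{
    block_1d_id %= M0 * N0; // swallow batch index, as in the diff

    const index_t band  = block_1d_id / (M01 * N0); // which M01-tall band
    const index_t rem   = block_1d_id % (M01 * N0); // position within the band

    // the last band may be shorter than M01 ("dynamically adjusts M01")
    const index_t band_rows = (band + 1) * M01 <= M0 ? M01 : M0 - band * M01;

    const index_t m0 = band * M01 + rem % band_rows;
    const index_t n0 = rem / band_rows;
    return {m0, n0};
}
```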
@@ -209,11 +216,36 @@ struct BlockToCTileMap_M00_N0_M01Adapt
         return true; // always valid provided that user gets grid size from CalculateGridSize()
     }

-    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
-
     private:
+    index_t M_;
+    index_t N_;
     index_t M01_;
-    CGridDesc_M_N c_grid_desc_m_n_;
 };

+template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
+struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
+{
+    using Parent = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>;
+
+    using Parent::I0;
+    using Parent::I1;
+
+    using Parent::Parent;
+    using Parent::operator=;
+
+    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
+                                                        index_t M01 = 8)
+        : Parent(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
+    {
+    }
+
+    __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
+    {
+        return Parent::CalculateGridSize(c_grid_desc_m_n.GetLength(I0),
+                                         c_grid_desc_m_n.GetLength(I1));
+    }
+
+    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
+};
+
 // 2D slices of column-vectors in 3D space
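The split into a void-specialized base and a thin descriptor-aware wrapper is a reusable idiom: the base stores only `(M, N)`, the wrapper adapts any descriptor type by forwarding its two lengths. A self-contained reproduction of the pattern (names like `TileMap` and `GetLength0` are invented for the sketch):

```cpp
using index_t = int;

// Primary template, trailing parameter defaults to void.
template <index_t MPerBlock, index_t NPerBlock, typename CDesc = void>
struct TileMap;

// Lightweight void specialization: stores only (M, N, M01).
template <index_t MPerBlock, index_t NPerBlock>
struct TileMap<MPerBlock, NPerBlock, void>
{
    TileMap() = default;
    TileMap(index_t M, index_t N, index_t M01 = 8) : M_(M), N_(N), M01_(M01) {}

    static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        return ((M + MPerBlock - 1) / MPerBlock) * ((N + NPerBlock - 1) / NPerBlock);
    }

    private:
    index_t M_{}, N_{}, M01_{};
};

// Thin wrapper: adapts a descriptor-like type to the (M, N) interface.
template <index_t MPerBlock, index_t NPerBlock, typename CDesc>
struct TileMap : TileMap<MPerBlock, NPerBlock, void>
{
    using Parent = TileMap<MPerBlock, NPerBlock, void>;
    using Parent::Parent;

    explicit TileMap(const CDesc& desc, index_t M01 = 8)
        : Parent(desc.GetLength0(), desc.GetLength1(), M01)
    {
    }
};
```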
@@ -17,17 +17,25 @@

 namespace ck {

-template <typename GridwiseGemm,
-          typename FloatAB,
-          typename FloatC,
-          typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          typename AGridDesc_AK0_M_AK1,
-          typename BGridDesc_BK0_N_BK1,
-          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-          typename Block2CTileMap,
-          bool HasMainKBlockLoop>
+template <typename GridwiseGemm, bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
+    kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
+    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+    GridwiseGemm::template Run<HasMainKBlockLoop>(
+        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
+#else
+    ignore = karg;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+}
+
+template <typename GridwiseGemm, typename FloatAB, typename FloatC, bool HasMainKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
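Passing `GridwiseGemm::Argument` to a `__global__` function by value only works because the object is trivially copyable. A minimal runnable HIP sketch of the same pattern (the `Arg` struct and sizes are stand-ins, not CK's types):

```cpp
#include <hip/hip_runtime.h>
#include <type_traits>

// Stand-in for GridwiseGemm::Argument: plain data, no virtuals, no pointers-to-host-only state.
struct Arg
{
    const float* p_a;
    const float* p_b;
    float* p_c;
    int M, N, K;
};

__global__ void kernel(Arg karg) // by value: the runtime memcpys the whole object
{
    (void)karg; // body elided; a real kernel would rebuild descriptors from karg
}

int main()
{
    static_assert(std::is_trivially_copyable_v<Arg>);
    Arg karg{nullptr, nullptr, nullptr, 0, 0, 0};
    hipLaunchKernelGGL(kernel, dim3(1), dim3(64), 0, nullptr, karg);
    return hipDeviceSynchronize() == hipSuccess ? 0 : 1;
}
```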
@@ -35,55 +43,33 @@ __global__ void
     kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
                                 const FloatAB* __restrict__ p_b_grid,
                                 FloatC* __restrict__ p_c_grid,
-                                const AElementwiseOperation a_element_op,
-                                const BElementwiseOperation b_element_op,
-                                const CElementwiseOperation c_element_op,
-                                const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-                                const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-                                const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                const Block2CTileMap block_2_ctile_map)
+                                typename GridwiseGemm::Problem problem)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
-                                                  p_b_grid,
-                                                  p_c_grid,
-                                                  p_shared,
-                                                  a_element_op,
-                                                  b_element_op,
-                                                  c_element_op,
-                                                  a_grid_desc_ak0_m_ak1,
-                                                  b_grid_desc_bk0_n_bk1,
-                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                  block_2_ctile_map);
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
     ignore = p_c_grid;
-    ignore = a_element_op;
-    ignore = b_element_op;
-    ignore = c_element_op;
-    ignore = a_grid_desc_ak0_m_ak1;
-    ignore = b_grid_desc_bk0_n_bk1;
-    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
-    ignore = block_2_ctile_map;
+    ignore = problem;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

-template <typename FloatAB,
+template <typename ALayout,
+          typename BLayout,
+          typename CLayout,
+          typename FloatAB,
           typename FloatGemmAcc,
           typename FloatCShuffle,
           typename FloatC,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
+          tensor_operation::device::GemmSpecialization GemmSpec,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-          typename AGridDesc_AK0_M_AK1,
-          typename BGridDesc_BK0_N_BK1,
-          typename CGridDesc_M_N,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
           index_t MPerBlock,
@@ -129,35 +115,396 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     static constexpr auto I7 = Number<7>{};

     // K1 should be Number<...>
-    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
-    static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
-    static constexpr auto AK1 = Number<AK1Value>{};
-    static constexpr auto BK1 = Number<BK1Value>{};
+    static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
+    static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
+    static constexpr auto AK1Number = Number<AK1Value>{};
+    static constexpr auto BK1Number = Number<BK1Value>{};

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

+    __host__ static auto CalculateGridSize(index_t M, index_t N)
+    {
+        return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
+    }
+
+    __host__ static auto CalculateMPadded(index_t M)
+    {
+        return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
+    }
+
+    __host__ static auto CalculateNPadded(index_t N)
+    {
+        return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
+    }
+
+    __host__ static auto CalculateKPadded(index_t K)
+    {
+        return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
+    }
+
+    __host__ static auto CalculateAK0(index_t K)
+    {
+        using GemmSpecialization = tensor_operation::device::GemmSpecialization;
+
+        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
+                     GemmSpec == GemmSpecialization::MNKPadding ||
+                     GemmSpec == GemmSpecialization::KPadding ||
+                     GemmSpec == GemmSpecialization::NKPadding)
+        {
+            return CalculateKPadded(K) / AK1Value;
+        }
+        else
+        {
+            return K / AK1Value;
+        }
+    }
+
+    __host__ static auto CalculateBK0(index_t K)
+    {
+        using GemmSpecialization = tensor_operation::device::GemmSpecialization;
+
+        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
+                     GemmSpec == GemmSpecialization::MNKPadding ||
+                     GemmSpec == GemmSpecialization::KPadding ||
+                     GemmSpec == GemmSpecialization::MKPadding)
+        {
+            return CalculateKPadded(K) / BK1Value;
+        }
+        else
+        {
+            return K / BK1Value;
+        }
+    }
+
+    __host__ static auto CalculateMBlock(index_t M)
+    {
+        return math::integer_divide_floor(M, MPerBlock);
+    }
+
+    __host__ static auto CalculateNBlock(index_t N)
+    {
+        return math::integer_divide_floor(N, NPerBlock);
+    }
+
+    __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
+        index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
+    {
+        const auto a_grid_desc_mraw_kraw = [&]() {
+            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
+            }
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
+            }
+        }();
+
+        using GemmSpecialization = tensor_operation::device::GemmSpecialization;
+
+        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
+                     GemmSpec == GemmSpecialization::MNKPadding)
+        {
+            // pad both M and K
+            const auto a_grid_desc_m_k =
+                transform_tensor_descriptor(a_grid_desc_mraw_kraw,
+                                            make_tuple(make_right_pad_transform(M, MPad - M),
+                                                       make_right_pad_transform(K, KPad - K)),
+                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
+                                            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
+                a_grid_desc_m_k,
+                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
+                           make_pass_through_transform(MPad)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return a_grid_desc_ak0_m_ak1;
+        }
+        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+                          GemmSpec == GemmSpecialization::MNPadding)
+        {
+            // pad M, but not K
+            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
+                a_grid_desc_mraw_kraw,
+                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
+                           make_right_pad_transform(M, MPad - M)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return a_grid_desc_ak0_m_ak1;
+        }
+        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
+                          GemmSpec == GemmSpecialization::NKPadding)
+        {
+            // pad K, but not M
+            const auto a_grid_desc_m_k = transform_tensor_descriptor(
+                a_grid_desc_mraw_kraw,
+                make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
+                a_grid_desc_m_k,
+                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
+                           make_pass_through_transform(M)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return a_grid_desc_ak0_m_ak1;
+        }
+        else
+        {
+            // not pad M or K
+            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
+                a_grid_desc_mraw_kraw,
+                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
+                           make_pass_through_transform(M)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return a_grid_desc_ak0_m_ak1;
+        }
+    }
+
+    __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
+        index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
+    {
+        const auto b_grid_desc_nraw_kraw = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
+            }
+        }();
+
+        using GemmSpecialization = tensor_operation::device::GemmSpecialization;
+
+        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
+                     GemmSpec == GemmSpecialization::MNKPadding)
+        {
+            // pad both N and K
+            const auto b_grid_desc_n_k =
+                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
+                                            make_tuple(make_right_pad_transform(N, NPad - N),
+                                                       make_right_pad_transform(K, KPad - K)),
+                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
+                                            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b_grid_desc_n_k,
+                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
+                           make_pass_through_transform(NPad)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return b_grid_desc_bk0_n_bk1;
+        }
+        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+                          GemmSpec == GemmSpecialization::MNPadding)
+        {
+            // pad N, but not K
+            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b_grid_desc_nraw_kraw,
+                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
+                           make_right_pad_transform(N, NPad - N)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return b_grid_desc_bk0_n_bk1;
+        }
+        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
+                          GemmSpec == GemmSpecialization::MKPadding)
+        {
+            // pad K, but not N
+            const auto b_grid_desc_n_k = transform_tensor_descriptor(
+                b_grid_desc_nraw_kraw,
+                make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b_grid_desc_n_k,
+                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
+                           make_pass_through_transform(N)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return b_grid_desc_bk0_n_bk1;
+        }
+        else
+        {
+            // not pad N or K
+            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b_grid_desc_nraw_kraw,
+                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
+                           make_pass_through_transform(N)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return b_grid_desc_bk0_n_bk1;
+        }
+    }
+
+    __host__ __device__ static auto
+    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
+    {
+        const auto c_grid_desc_mraw_nraw = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
+            }
+        }();
+
+        using GemmSpecialization = tensor_operation::device::GemmSpecialization;
+
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
+                     GemmSpec == GemmSpecialization::MNKPadding)
+        {
+            // pad M and N
+            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
+                                               make_tuple(make_right_pad_transform(M, MPad - M),
+                                                          make_right_pad_transform(N, NPad - N)),
+                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
+                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+                          GemmSpec == GemmSpecialization::MKPadding)
+        {
+            // pad M, but not N
+            return transform_tensor_descriptor(
+                c_grid_desc_mraw_nraw,
+                make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+                          GemmSpec == GemmSpecialization::NKPadding)
+        {
+            // pad N, but not M
+            return transform_tensor_descriptor(
+                c_grid_desc_mraw_nraw,
+                make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+        else
+        {
+            // not pad M or N
+            return c_grid_desc_mraw_nraw;
+        }
+    }
+
+    struct Problem
+    {
+        __host__ Problem(index_t M_,
+                         index_t N_,
+                         index_t K_,
+                         index_t StrideA_,
+                         index_t StrideB_,
+                         index_t StrideC_)
+            : M{M_},
+              N{N_},
+              K{K_},
+              StrideA{StrideA_},
+              StrideB{StrideB_},
+              StrideC{StrideC_},
+              MPadded{CalculateMPadded(M_)},
+              NPadded{CalculateNPadded(N_)},
+              KPadded{CalculateKPadded(K_)},
+              AK0{CalculateAK0(K_)},
+              BK0{CalculateBK0(K_)},
+              MBlock{CalculateMBlock(M_)},
+              NBlock{CalculateNBlock(N_)}
+        {
+        }
+
+        __host__ void Print() const
+        {
+            std::cout << "problem {"
+                      << "M:" << M << ", "
+                      << "N:" << N << ", "
+                      << "K:" << K << ", "
+                      << "SA:" << StrideA << ", "
+                      << "SB:" << StrideB << ", "
+                      << "SC:" << StrideC << ", "
+                      << "MP:" << MPadded << ", "
+                      << "NP:" << NPadded << ", "
+                      << "KP:" << KPadded << ", "
+                      << "AK0:" << AK0 << ", "
+                      << "BK0:" << BK0 << ", "
+                      << "MBlock: " << MBlock << ", "
+                      << "NBlock: " << NBlock << "}" << std::endl;
+        }
+
+        index_t M;
+        index_t N;
+        index_t K;
+        index_t StrideA;
+        index_t StrideB;
+        index_t StrideC;
+        index_t MPadded;
+        index_t NPadded;
+        index_t KPadded;
+        index_t AK0;
+        index_t BK0;
+        index_t MBlock;
+        index_t NBlock;
+    };
+
+    // Argument
+    struct Argument : public tensor_operation::device::BaseArgument, public Problem
+    {
+        __host__ Argument(const FloatAB* p_a_grid_,
+                          const FloatAB* p_b_grid_,
+                          FloatC* p_c_grid_,
+                          index_t M_,
+                          index_t N_,
+                          index_t K_,
+                          index_t StrideA_,
+                          index_t StrideB_,
+                          index_t StrideC_)
+            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
+              p_a_grid{p_a_grid_},
+              p_b_grid{p_b_grid_},
+              p_c_grid{p_c_grid_}
+        {
+        }
+
+        const FloatAB* p_a_grid;
+        const FloatAB* p_b_grid;
+        FloatC* p_c_grid;
+    };
+
     // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
     using GridwiseGemmPipe = remove_cvref_t<decltype(
         GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

-    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
+    __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
         return make_naive_tensor_descriptor(
-            make_tuple(AK0, Number<MPerBlock>{}, AK1),
-            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
+            make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
+            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
     }

-    __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
+    __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
     {
         // B matrix in LDS memory, dst of blockwise copy
         return make_naive_tensor_descriptor(
-            make_tuple(BK0, Number<NPerBlock>{}, BK1),
-            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
+            make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
+            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
     }

-    __host__ __device__ static constexpr auto
-    GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
+    __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
     {
         constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
         constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
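To make the derived `Problem` fields concrete: with `MPerBlock = NPerBlock = 128`, `KPerBlock = 32`, `AK1Value = 8`, and an M = N = 1000, K = 100 problem under a K-padding specialization, the constructor computes MPadded = NPadded = 1024, KPadded = 128, and AK0 = 16. A host-only sketch of the same arithmetic, with local stand-ins for the `ck::math` helpers:

```cpp
using index_t = int;

// host-only stand-in for ck::math::integer_divide_ceil
constexpr index_t divide_ceil(index_t x, index_t y) { return (x + y - 1) / y; }

constexpr index_t MPerBlock = 128, NPerBlock = 128, KPerBlock = 32, AK1Value = 8;
constexpr index_t M = 1000, N = 1000, K = 100;

constexpr index_t MPadded = divide_ceil(M, MPerBlock) * MPerBlock; // 1024
constexpr index_t NPadded = divide_ceil(N, NPerBlock) * NPerBlock; // 1024
constexpr index_t KPadded = divide_ceil(K, KPerBlock) * KPerBlock; // 128
constexpr index_t AK0     = KPadded / AK1Value;                    // 16, K-padding path

static_assert(MPadded == 1024 && NPadded == 1024 && KPadded == 128 && AK0 == 16);
```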
@@ -172,14 +519,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
     }

-    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
+    __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
         constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

         // lds max alignment
-        constexpr auto max_lds_align = math::lcm(AK1, BK1);
+        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

         constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
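The LDS sizing here rounds the A tile's element count up to a multiple of `lcm(AK1, BK1)` so the B tile starts on an aligned offset. A worked sketch with a local stand-in for `math::integer_least_multiple`:

```cpp
#include <numeric>

using index_t = int;

// stand-in for ck::math::integer_least_multiple:
// smallest multiple of `align` that is >= x
constexpr index_t integer_least_multiple(index_t x, index_t align)
{
    return ((x + align - 1) / align) * align;
}

// e.g. AK1 = 8, BK1 = 4 -> max_lds_align = lcm(8, 4) = 8; a 2052-element
// A tile is then padded to 2056 elements before the B tile begins.
static_assert(std::lcm(8, 4) == 8);
static_assert(integer_least_multiple(2052, 8) == 2056);
```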
@@ -200,36 +547,102 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     }

-    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
-    template <typename Block2CTileMap>
-    __host__ __device__ static constexpr bool
-    CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
-                  const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-                  const CGridDesc_M_N& c_grid_desc_m_n,
-                  const Block2CTileMap& block_2_ctile_map)
+    __host__ static constexpr bool CheckValidity(const Problem& problem)
     {
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                       "Invalid tuning param!");

-        const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
-        const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
-        const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
-
-        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
-            return false;
-
-        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
-            return false;
-
-        // check gridwise gemm pipeline
-        const auto num_k_loop = K / KPerBlock;
-
-        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
-        {
-            return false;
-        }
-
-        if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
-        {
-            return false;
-        }
+        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+        {
+            if(!(problem.M % MPerBlock == 0))
+            {
+                return false;
+            }
+        }
+
+        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+        {
+            if(!(problem.N % NPerBlock == 0))
+            {
+                return false;
+            }
+        }
+
+        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding)
+        {
+            if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
+               !(CalculateKPadded(problem.K) % BK1Value == 0))
+            {
+                return false;
+            }
+        }
+        else
+        {
+            if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
+            {
+                return false;
+            }
+        }
+
+        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
+        {
+            if(problem.K % ABlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            if(problem.M % ABlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+
+        if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+        {
+            if(problem.N % BBlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            if(problem.K % BBlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+
+        if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
+        {
+            if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+        }
+
+        // check gridwise gemm pipeline
+        const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
+        {
+            return false;
+        }
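`CheckValidity` is now a pure host-side predicate over the `Problem`, so each shape check can be resolved per `GemmSpec` at compile time and only the matching runtime test survives. A toy, self-contained mirror of one such check (`PadsM` stands in for the GemmSpec-based condition in the diff):

```cpp
// If the chosen specialization does not pad M, then M must already be a
// multiple of MPerBlock; a padding specialization accepts any M.
template <bool PadsM, int MPerBlock>
constexpr bool check_m(int M)
{
    if constexpr(!PadsM)
    {
        return M % MPerBlock == 0;
    }
    else
    {
        return true;
    }
}

static_assert(check_m<false, 128>(1024));  // 8 * 128, accepted
static_assert(!check_m<false, 128>(1000)); // rejected: needs an M-padding spec
static_assert(check_m<true, 128>(1000));   // padding makes it valid
```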
@@ -238,22 +651,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         return true;
     }

-    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
+    __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
         const index_t num_loop = K / KPerBlock;

         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }

-    __host__ __device__ static constexpr auto
-    MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
+    template <typename CGridDesc>
+    __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
     {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-
-        const auto MBlock = M / MPerBlock;
-        const auto NBlock = N / NPerBlock;
-
         const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
             c_grid_desc_m_n,
             make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
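The unmerge this descriptor performs splits the flat m coordinate into `(mblock, mperblock) = (m / MPerBlock, m % MPerBlock)` (and likewise for n); with `MBlock`/`NBlock` now supplied from `Problem`, the device code no longer does the division itself. A two-line worked check:

```cpp
constexpr int MPerBlock = 256, M = 1024;
constexpr int MBlock = M / MPerBlock; // 4, precomputed host-side in Problem
// flat m -> (mblock, mperblock): 1000 -> (3, 232), since 3 * 256 + 232 == 1000
static_assert(1000 / MPerBlock == 3 && 1000 % MPerBlock == 232);
static_assert(MBlock == 4);
```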
@@ -265,33 +673,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     }

-    // return block_id to C matrix tile idx (m0, n0) mapping
-    __host__ __device__ static constexpr auto
-    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
-    {
-        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
-            c_grid_desc_m_n);
-    }
+    using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;

-    using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
-
-    using DefaultBlock2CTileMap =
-        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
-
-    template <bool HasMainKBlockLoop, typename Block2CTileMap>
+    template <bool HasMainKBlockLoop>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
                                void* __restrict__ p_shared,
-                               const AElementwiseOperation& a_element_op,
-                               const BElementwiseOperation& b_element_op,
-                               const CElementwiseOperation& c_element_op,
-                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
-                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-                               const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
-                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
-                               const Block2CTileMap& block_2_ctile_map)
+                               const Problem& problem)
     {
+        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
+            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
+        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
+            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
+        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
+            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
+
+        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
+            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                c_grid_desc_m_n, problem.MBlock, problem.NBlock);
+
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -299,7 +700,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

+        const AElementwiseOperation a_element_op{};
+        const BElementwiseOperation b_element_op{};
+        const CElementwiseOperation c_element_op{};
+
         // divide block work by [M, N]
+        const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N};
+
         const auto block_work_idx =
             block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
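Dropping the elementwise-op kernel parameters is safe because the ops used here are stateless, so the kernel can default-construct them on the fly, as the diff does. A minimal illustration of why nothing is lost (a PassThrough-style functor, written for the sketch rather than copied from CK):

```cpp
#include <type_traits>

// Stateless identity functor in the style of the a/b/c element ops.
struct PassThrough
{
    template <typename Y, typename X>
    constexpr void operator()(Y& y, const X& x) const
    {
        y = x; // no captured state, so default construction loses nothing
    }
};

static_assert(std::is_empty_v<PassThrough>);
static_assert(std::is_default_constructible_v<PassThrough>);
```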
@@ -319,7 +726,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

         // lds max alignment
-        constexpr auto max_lds_align = math::lcm(AK1, BK1);
+        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

         // A matrix in LDS memory, dst of blockwise copy
         constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
@@ -333,7 +740,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             AElementwiseOperation,
             ck::tensor_operation::element_wise::PassThrough,
             InMemoryDataOperationEnum::Set,
-            Sequence<AK0, MPerBlock, AK1>,
+            Sequence<AK0Number, MPerBlock, AK1Number>,
             ABlockTransferThreadClusterLengths_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
             FloatAB,
@@ -364,7 +771,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             BElementwiseOperation,
             ck::tensor_operation::element_wise::PassThrough,
             InMemoryDataOperationEnum::Set,
-            Sequence<BK0, NPerBlock, BK1>,
+            Sequence<BK0Number, NPerBlock, BK1Number>,
             BBlockTransferThreadClusterLengths_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
             FloatAB,
@@ -396,8 +803,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         // register
         // sanity check
-        constexpr index_t KPack = math::max(
-            math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+        constexpr index_t KPack =
+            math::max(math::lcm(AK1Number, BK1Number),
+                      MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
             BlockSize,
@@ -425,8 +833,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());

-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);

         // gridwise GEMM pipeline
         static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);