mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
gridwise update
This commit is contained in:
@@ -4,7 +4,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
@@ -140,24 +139,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
static constexpr auto AK1Number = Number<AK1Value>{};
|
||||
static constexpr auto BK1Number = Number<BK1Value>{};
|
||||
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
static constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeA,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
math::max(math::lcm(AK1Number, BK1Number),
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
__host__ static auto CalculateMPadded(index_t M)
|
||||
@@ -239,23 +223,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
}();
|
||||
|
||||
// Pad both M and K to be multiples of the block sizes
|
||||
const auto a_grid_desc_m_k =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_right_pad_transform(M, MPad - M),
|
||||
make_right_pad_transform(K, KPad - K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
|
||||
make_pass_through_transform(MPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
#if 0
|
||||
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
@@ -322,7 +289,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
|
||||
@@ -339,23 +305,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
}();
|
||||
|
||||
// Pad both N and K to be multiples of the block sizes
|
||||
const auto b_grid_desc_n_k =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_right_pad_transform(N, NPad - N),
|
||||
make_right_pad_transform(K, KPad - K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
|
||||
make_pass_through_transform(NPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
#if 0
|
||||
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
|
||||
@@ -422,7 +371,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ABlockDesc_AK0_M_AK1>
|
||||
@@ -457,13 +405,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
}();
|
||||
|
||||
// Pad both M and N to be multiples of the block sizes
|
||||
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(M, MPad - M),
|
||||
make_right_pad_transform(N, NPad - N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
#if 0
|
||||
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
|
||||
@@ -501,7 +442,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
// not pad M or N
|
||||
return c_grid_desc_mraw_nraw;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
struct Problem
|
||||
@@ -513,8 +453,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t Streamk_sel_,
|
||||
index_t Grid_size_,
|
||||
StreamKReductionStrategy reduction_strategy_)
|
||||
index_t Grid_size_)
|
||||
: M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
@@ -523,7 +462,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
StrideC{StrideC_},
|
||||
Streamk_sel{Streamk_sel_},
|
||||
Grid_size{Grid_size_},
|
||||
reduction_strategy{reduction_strategy_}, // Initialize the member variable
|
||||
MPadded{CalculateMPadded(M_)},
|
||||
NPadded{CalculateNPadded(N_)},
|
||||
KRead{CalculateKRead(K_, 1)},
|
||||
@@ -552,13 +490,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
<< "AK0:" << AK0 << ", "
|
||||
<< "BK0:" << BK0 << ", "
|
||||
<< "MBlock: " << MBlock << ", "
|
||||
<< "NBlock: " << NBlock << ", "
|
||||
<< "Stream-K Selection:" << Streamk_sel << ", "
|
||||
<< "Grid size:" << Grid_size << ", "
|
||||
<< "Reduction Strategy:"
|
||||
<< (reduction_strategy == StreamKReductionStrategy::Atomic ? "Atomic"
|
||||
: "Reduction")
|
||||
<< "}" << std::endl;
|
||||
<< "NBlock: " << NBlock << ", Stream-K Selection:" << Streamk_sel
|
||||
<< ", Grid size:" << Grid_size << "}" << std::endl;
|
||||
}
|
||||
|
||||
index_t M;
|
||||
@@ -569,7 +502,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
index_t StrideC;
|
||||
index_t Streamk_sel;
|
||||
mutable index_t Grid_size;
|
||||
StreamKReductionStrategy reduction_strategy;
|
||||
index_t MPadded;
|
||||
index_t NPadded;
|
||||
index_t KRead;
|
||||
@@ -593,26 +525,14 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t Streamk_sel_,
|
||||
index_t Grid_size_,
|
||||
StreamKReductionStrategy reduction_strategy_)
|
||||
: Problem{M_,
|
||||
N_,
|
||||
K_,
|
||||
StrideA_,
|
||||
StrideB_,
|
||||
StrideC_,
|
||||
Streamk_sel_,
|
||||
Grid_size_,
|
||||
reduction_strategy_},
|
||||
index_t Grid_size_)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_c_grid{p_c_grid_},
|
||||
block_2_ctile_map_streamk(M_,
|
||||
N_,
|
||||
AK0Number * CalculateKPadded(K_, 1),
|
||||
Grid_size_,
|
||||
Streamk_sel_,
|
||||
reduction_strategy_)
|
||||
block_2_ctile_map_streamk(
|
||||
M_, N_, AK0Number * CalculateKPadded(K_, 1), Grid_size_, Streamk_sel_),
|
||||
launch_grid_dims_{0, 0, 0} // Initialize grid dims to zero
|
||||
|
||||
{
|
||||
}
|
||||
@@ -627,6 +547,18 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
8,
|
||||
4>
|
||||
block_2_ctile_map_streamk;
|
||||
|
||||
mutable dim3 launch_grid_dims_;
|
||||
|
||||
// Records the given grid dimensions in launch_grid_dims_.
// The method is const-qualified but still assigns the member; this compiles
// because launch_grid_dims_ is declared `mutable` just above.
// @param dims  grid dimensions to store (copied by value).
void SetLaunchGridDims(dim3 dims) const
|
||||
{
|
||||
launch_grid_dims_ = dims;
|
||||
}
|
||||
|
||||
// Returns a copy of the grid dimensions currently held in launch_grid_dims_.
// The constructor visible in this diff initializes that member to {0, 0, 0},
// so that is the value returned until SetLaunchGridDims() stores another one.
dim3 GetLaunchGridDims() const
|
||||
{
|
||||
return launch_grid_dims_;
|
||||
}
|
||||
};
|
||||
|
||||
struct SplitKBatchOffset
|
||||
@@ -1027,8 +959,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
|
||||
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
|
||||
{
|
||||
if(!(karg.M % MPerBlock == 0))
|
||||
{
|
||||
@@ -1045,8 +976,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
|
||||
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
|
||||
{
|
||||
if(!(karg.N % NPerBlock == 0))
|
||||
{
|
||||
@@ -1112,7 +1042,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1128,10 +1057,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1146,7 +1071,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1164,7 +1088,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1181,11 +1104,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet"
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// check gridwise gemm pipeline
|
||||
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
|
||||
|
||||
@@ -1288,13 +1220,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
|
||||
problem.N,
|
||||
AK0Number * problem.KPadded,
|
||||
problem.Grid_size,
|
||||
problem.Streamk_sel,
|
||||
problem.reduction_strategy);
|
||||
problem.Streamk_sel);
|
||||
uint32_t iter_start, iter_end;
|
||||
bool is_sk_block, is_dp_block, is_reduction_block;
|
||||
index_t num_k_block_main_loop;
|
||||
@@ -1309,7 +1239,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
|
||||
reinterpret_cast<char*>(p_workspace) +
|
||||
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
|
||||
|
||||
for(auto block_idx = get_block_1d_id();
|
||||
block_idx < block_2_ctile_map_streamk.get_grid_dims();
|
||||
block_idx += gridDim.x)
|
||||
@@ -1325,7 +1254,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
|
||||
num_k_block_main_loop = iter_end - iter_start;
|
||||
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
is_reduction_block = static_cast<uint32_t>(block_idx) >=
|
||||
block_2_ctile_map_streamk.reduction_start_block_idx;
|
||||
@@ -1913,7 +1843,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
else if(is_sk_block)
|
||||
{
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Atomic)
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global
|
||||
@@ -1925,8 +1856,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
}
|
||||
else if(problem.reduction_strategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// constexpr offset
|
||||
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
|
||||
@@ -1958,7 +1889,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
});
|
||||
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
if(is_sk_block)
|
||||
{
|
||||
@@ -1973,7 +1905,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
iter_end -= current_iter_length;
|
||||
if(iter_end <= iter_start)
|
||||
break;
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
block_acc_offset -= MPerBlock * NPerBlock;
|
||||
}
|
||||
@@ -2028,8 +1961,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
problem.N,
|
||||
AK0Number * problem.KPadded,
|
||||
problem.Grid_size,
|
||||
problem.Streamk_sel,
|
||||
problem.reduction_strategy);
|
||||
problem.Streamk_sel);
|
||||
for(auto block_idx = get_block_1d_id();
|
||||
block_idx < block_2_ctile_map_streamk.get_grid_dims();
|
||||
block_idx += gridDim.x)
|
||||
@@ -2048,7 +1980,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
reinterpret_cast<char*>(p_workspace) +
|
||||
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
|
||||
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
is_reduction_block = static_cast<uint32_t>(block_idx) >=
|
||||
block_2_ctile_map_streamk.reduction_start_block_idx;
|
||||
@@ -2664,7 +2597,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
else if(is_sk_block)
|
||||
{
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Atomic)
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global
|
||||
@@ -2676,8 +2610,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
}
|
||||
else if(problem.reduction_strategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// constexpr offset
|
||||
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
|
||||
@@ -2712,14 +2646,16 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
iter_end -= current_iter_length;
|
||||
if(iter_end <= iter_start)
|
||||
break;
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
block_acc_offset -= MPerBlock * NPerBlock;
|
||||
}
|
||||
// make sure next loop LDS is ready for use
|
||||
block_sync_lds();
|
||||
}
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
if(is_sk_block)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user