gridwise update

This commit is contained in:
ozturkosu
2025-06-13 04:54:51 -04:00
parent 2a88a04999
commit d7e80d2ae0

View File

@@ -4,7 +4,6 @@
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/env.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
@@ -140,24 +139,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
static constexpr auto is_scale_mfma = false;
static constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeA,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
math::max(math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateMPadded(index_t M)
@@ -239,23 +223,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
}();
// Pad both M and K to be multiples of the block sizes
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
@@ -322,7 +289,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return a_grid_desc_ak0_m_ak1;
}
#endif
}
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
@@ -339,23 +305,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
}();
// Pad both N and K to be multiples of the block sizes
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(N, NPad - N),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
@@ -422,7 +371,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return b_grid_desc_bk0_n_bk1;
}
#endif
}
template <typename ABlockDesc_AK0_M_AK1>
@@ -457,13 +405,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
}();
// Pad both M and N to be multiples of the block sizes
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
@@ -501,7 +442,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
// not pad M or N
return c_grid_desc_mraw_nraw;
}
#endif
}
struct Problem
@@ -513,8 +453,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
index_t StrideB_,
index_t StrideC_,
index_t Streamk_sel_,
index_t Grid_size_,
StreamKReductionStrategy reduction_strategy_)
index_t Grid_size_)
: M{M_},
N{N_},
K{K_},
@@ -523,7 +462,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
StrideC{StrideC_},
Streamk_sel{Streamk_sel_},
Grid_size{Grid_size_},
reduction_strategy{reduction_strategy_}, // Initialize the member variable
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
KRead{CalculateKRead(K_, 1)},
@@ -552,13 +490,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< "AK0:" << AK0 << ", "
<< "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", "
<< "NBlock: " << NBlock << ", "
<< "Stream-K Selection:" << Streamk_sel << ", "
<< "Grid size:" << Grid_size << ", "
<< "Reduction Strategy:"
<< (reduction_strategy == StreamKReductionStrategy::Atomic ? "Atomic"
: "Reduction")
<< "}" << std::endl;
<< "NBlock: " << NBlock << ", Stream-K Selection:" << Streamk_sel
<< ", Grid size:" << Grid_size << "}" << std::endl;
}
index_t M;
@@ -569,7 +502,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
index_t StrideC;
index_t Streamk_sel;
mutable index_t Grid_size;
StreamKReductionStrategy reduction_strategy;
index_t MPadded;
index_t NPadded;
index_t KRead;
@@ -593,26 +525,14 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
index_t StrideB_,
index_t StrideC_,
index_t Streamk_sel_,
index_t Grid_size_,
StreamKReductionStrategy reduction_strategy_)
: Problem{M_,
N_,
K_,
StrideA_,
StrideB_,
StrideC_,
Streamk_sel_,
Grid_size_,
reduction_strategy_},
index_t Grid_size_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
block_2_ctile_map_streamk(M_,
N_,
AK0Number * CalculateKPadded(K_, 1),
Grid_size_,
Streamk_sel_,
reduction_strategy_)
block_2_ctile_map_streamk(
M_, N_, AK0Number * CalculateKPadded(K_, 1), Grid_size_, Streamk_sel_),
launch_grid_dims_{0, 0, 0} // Initialize grid dims to zero
{
}
@@ -627,6 +547,18 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
8,
4>
block_2_ctile_map_streamk;
mutable dim3 launch_grid_dims_;
// Overwrites the stored launch_grid_dims_ with `dims`.
// The method is const-qualified; the write is legal only because
// launch_grid_dims_ is declared `mutable`, which lets callers holding a
// const reference to this object still record the value.
void SetLaunchGridDims(dim3 dims) const { this->launch_grid_dims_ = dims; }
// Returns, by value, whatever is currently held in launch_grid_dims_.
// The constructor visible in this file initializes that member to
// {0, 0, 0}, so that is the result until SetLaunchGridDims() stores
// something else.
dim3 GetLaunchGridDims() const { return this->launch_grid_dims_; }
};
struct SplitKBatchOffset
@@ -1027,8 +959,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.M % MPerBlock == 0))
{
@@ -1045,8 +976,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.N % NPerBlock == 0))
{
@@ -1112,7 +1042,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
}
@@ -1128,10 +1057,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
return false;
}
}
@@ -1146,7 +1071,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
}
@@ -1164,7 +1088,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
}
}
@@ -1181,11 +1104,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
}
}
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
}
// check gridwise gemm pipeline
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
@@ -1288,13 +1220,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
problem.N,
AK0Number * problem.KPadded,
problem.Grid_size,
problem.Streamk_sel,
problem.reduction_strategy);
problem.Streamk_sel);
uint32_t iter_start, iter_end;
bool is_sk_block, is_dp_block, is_reduction_block;
index_t num_k_block_main_loop;
@@ -1309,7 +1239,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
reinterpret_cast<char*>(p_workspace) +
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
block_idx += gridDim.x)
@@ -1325,7 +1254,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
num_k_block_main_loop = iter_end - iter_start;
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
@@ -1913,7 +1843,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
else if(is_sk_block)
{
if(problem.reduction_strategy == StreamKReductionStrategy::Atomic)
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
@@ -1925,8 +1856,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
else if(problem.reduction_strategy ==
StreamKReductionStrategy::Reduction)
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
@@ -1958,7 +1889,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
});
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
@@ -1973,7 +1905,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
@@ -2028,8 +1961,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
problem.N,
AK0Number * problem.KPadded,
problem.Grid_size,
problem.Streamk_sel,
problem.reduction_strategy);
problem.Streamk_sel);
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
block_idx += gridDim.x)
@@ -2048,7 +1980,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
reinterpret_cast<char*>(p_workspace) +
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
@@ -2664,7 +2597,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
else if(is_sk_block)
{
if(problem.reduction_strategy == StreamKReductionStrategy::Atomic)
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
@@ -2676,8 +2610,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
else if(problem.reduction_strategy ==
StreamKReductionStrategy::Reduction)
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
@@ -2712,14 +2646,16 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
// make sure next loop LDS is ready for use
block_sync_lds();
}
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{