Merge commit '87dd073887933fc2c75c234871e3885cee970a98' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-18 00:34:53 +00:00
parent 3c59d702ca
commit 334ae1c494
82 changed files with 7696 additions and 622 deletions

View File

@@ -295,7 +295,7 @@ struct ABTransferThreadTiles
BlockDescriptor& block_descriptor,
ABElementwiseOperation& ab_element_op,
const index_t block_mn_id,
const index_t)
const index_t k_id)
{
constexpr index_t NumABTensor = ABsDataType::Size();
const index_t mn_block_data_idx_on_grid =
@@ -304,7 +304,7 @@ struct ABTransferThreadTiles
if constexpr(NumABTensor > 1)
{
const auto idx_as_block_begin = generate_tuple(
[&](auto) { return make_multi_index(0, mn_block_data_idx_on_grid, 0); },
[&](auto) { return make_multi_index(k_id, mn_block_data_idx_on_grid, 0); },
Number<NumABTensor>{});
return ThreadGroupTensorSliceTransfer_v7r2<
@@ -357,7 +357,7 @@ struct ABTransferThreadTiles
ABThreadTransferSrcResetCoordinateAfterRun,
true,
GlobalBufferNum>(grid_descriptor[I0],
make_multi_index(0, mn_block_data_idx_on_grid, 0),
make_multi_index(k_id, mn_block_data_idx_on_grid, 0),
ab_element_op,
block_descriptor,
make_multi_index(0, 0, 0),

View File

@@ -333,6 +333,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
struct Problem
{
__host__ Problem() = default;
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
@@ -409,6 +410,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
// Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
__host__ Argument() = default;
__host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
std::array<const void*, NumBTensor> p_bs_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
@@ -583,7 +585,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
const index_t A_k_id = 0,
const index_t B_k_id = 0)
{
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
@@ -651,7 +654,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
a_scale_struct,
b_scale_struct,
epilogue_args,
k_id);
A_k_id,
B_k_id);
}
template <bool HasMainKBlockLoop,
@@ -700,7 +704,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
Argument& karg,
const Block2CTileMap& block_2_ctile_map,
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
const index_t A_k_id = 0,
const index_t B_k_id = 0)
{
// shift A matrices pointer for splitk
AsGridPointer p_as_grid_splitk;
@@ -735,7 +740,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
karg.b_element_op,
karg.cde_element_op,
epilogue_args,
k_id);
A_k_id,
B_k_id);
}
// Wrapper function to have __global__ function in common
@@ -748,20 +754,146 @@ struct GridwiseGemm_wmma_cshuffle_v3
const SplitKBatchOffset& splitk_batch_offset,
Argument& karg,
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
const index_t A_k_id = 0,
const index_t B_k_id = 0)
{
Run<HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum,
Block2CTileMap,
EpilogueArgument>(
p_shared, splitk_batch_offset, karg, DefaultBlock2CTileMap(karg), epilogue_args, k_id);
EpilogueArgument>(p_shared,
splitk_batch_offset,
karg,
DefaultBlock2CTileMap(karg),
epilogue_args,
A_k_id,
B_k_id);
}
__device__ static auto DefaultBlock2CTileMap(const Problem& problem)
{
return Block2CTileMap{problem.M, problem.N, 4};
}
// Run method for convolution (grid descriptors are passed as arguments,
// not generated internally)
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename ComputePtrOffsetOfBatch,
index_t NumGroupsToMerge,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum,
typename EpilogueArgument>
__device__ static void Run(void* p_shared,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t num_k_per_block,
Argument& karg,
EpilogueArgument& epilogue_args)
{
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
AsGridPointer p_as_grid_;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
p_as_grid_(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset;
});
BsGridPointer p_bs_grid_;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
p_bs_grid_(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset;
});
const auto ds_grid_desc_m_n =
MakeDsGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, karg.MBlock, karg.NBlock);
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
[&](auto i) {
ignore = i;
return a_grid_desc_ak0_m_ak1;
},
Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
[&](auto i) {
ignore = i;
return b_grid_desc_bk0_n_bk1;
},
Number<NumBTensor>{});
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// Scale structs (Empty)
using Scale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = Scale{};
auto a_scale_struct = Scale{};
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(p_as_grid_,
p_bs_grid_,
karg.p_ds_grid,
karg.p_e_grid + e_batch_offset,
p_shared,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_desc_mblock_mperblock_nblock_nperblock,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args,
k_idx,
k_idx,
karg.KBatch);
}
};
} // namespace ck

View File

@@ -723,7 +723,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
const index_t A_k_id = 0,
const index_t B_k_id = 0)
{
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
@@ -793,7 +794,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
a_scale_struct,
b_scale_struct,
epilogue_args,
k_id);
A_k_id,
B_k_id);
}
// NOTE: Wrapper function to have __global__ function in common
@@ -806,7 +808,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
const SplitKBatchOffset& splitk_batch_offset,
Argument& karg,
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
const index_t A_k_id = 0,
const index_t B_k_id = 0)
{
// shift A matrices pointer for splitk
AsGridPointer p_as_grid_splitk;
@@ -857,7 +860,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
karg.b_element_op,
karg.cde_element_op,
epilogue_args,
k_id);
A_k_id,
B_k_id);
}
};

View File

@@ -101,7 +101,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg, epilogue_args, k_id);
p_shared,
splitk_batch_offset,
karg,
epilogue_args,
0, /* A_k_id == 0 (we shift the pointer for splitk) */
k_id);
#if defined(__gfx11__)
}
@@ -344,11 +349,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
// return block_id to C matrix tile idx (m0, n0) mapping
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// Calculate grid size taking into account splitk (KBatch)
// 2D grid (x,z)
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
}
// Calculate grid size taking into account splitk (KBatch) and multiple groups (Batch)
// 3D grid (x,y,z)
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch);
}
__host__ static auto CalculateMPadded(index_t M)
{
return math::integer_least_multiple(M, MPerBlock);
@@ -706,8 +720,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
ReduceTrait>;
template <typename DEGridDesc>
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
__host__ __device__ static constexpr auto
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc& de_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
{
const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
de_grid_desc_m_n,
@@ -1004,6 +1020,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
}
// Note: arguments k_batch and k_id should be set if splitk is used
// with implicit gemm (no pointer shift but shift using tensor descriptors)
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -1034,7 +1052,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
AScaleStruct& a_scale_struct,
BScaleStruct& b_scale_struct,
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
const index_t A_k_id = 0,
const index_t B_k_id = 0,
const index_t k_batch = 1)
{
const auto as_grid_buf = generate_tuple(
[&](auto i) {
@@ -1066,7 +1086,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
AsDataType,
AElementwiseOperation,
BlockwiseGemmPipe::GlobalBufferNum>(
as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id, k_id);
as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id, A_k_id);
// B matrix blockwise copy
auto b_blockwise_copy =
@@ -1075,7 +1095,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
BsDataType,
BElementwiseOperation,
BlockwiseGemmPipe::GlobalBufferNum>(
bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id, k_id);
bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id, B_k_id);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
@@ -1100,7 +1120,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / KPerBlock);
ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / (KPerBlock * k_batch));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
get_first_element_workaround<NumATensor>(as_grid_desc_ak0_m_ak1),