Merge commit '46f1d740f03d11bc2a78fce60a95cd0933b9dd4d' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-02 00:36:50 +00:00
parent 1b8a648333
commit 94dda8df22
30 changed files with 2291 additions and 268 deletions

View File

@@ -470,9 +470,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
const CDEElementwiseOperation cde_element_op;
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
CDEElementwiseOperation cde_element_op;
// TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
bool is_reduce;
@@ -555,13 +555,17 @@ struct GridwiseGemm_wmma_cshuffle_v3
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum,
typename EpilogueArgument>
typename Block2CTileMap,
typename EpilogueArgument,
int BlockMapMBlockIndex = 0,
int BlockMapNBlockIndex = 1>
__device__ static void Run(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const Problem& problem,
const Block2CTileMap& block_2_ctile_map,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
@@ -582,9 +586,6 @@ struct GridwiseGemm_wmma_cshuffle_v3
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
@@ -596,8 +597,10 @@ struct GridwiseGemm_wmma_cshuffle_v3
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
const index_t block_m_id =
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapMBlockIndex>{}]);
const index_t block_n_id =
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);
// BScale struct (Empty)
using BScale = typename BlockwiseGemmPipe::Empty;
@@ -632,15 +635,51 @@ struct GridwiseGemm_wmma_cshuffle_v3
epilogue_args);
}
// Convenience overload of Run for callers that do not provide their own
// block-to-C-tile map. It builds the map with DefaultBlock2CTileMap(problem)
// and hands everything else, unchanged, to the Run overload that takes an
// explicit block_2_ctile_map argument.
template <bool HasMainKBlockLoop,
          InMemoryDataOperationEnum EGlobalMemoryDataOperation,
          TailNumber TailNum,
          typename EpilogueArgument>
__device__ static void Run(AsGridPointer& p_as_grid,
                           BsGridPointer& p_bs_grid,
                           DsGridPointer& p_ds_grid,
                           EDataType* p_e_grid,
                           void* p_shared,
                           const Problem& problem,
                           AElementwiseOperation a_element_op,
                           BElementwiseOperation b_element_op,
                           CDEElementwiseOperation cde_element_op,
                           EpilogueArgument& epilogue_args)
{
    // Tile map derived from the problem description only.
    const auto default_tile_map = DefaultBlock2CTileMap(problem);

    // Block2CTileMap is named explicitly so that the call resolves to the
    // overload that accepts a tile map as its seventh argument.
    Run<HasMainKBlockLoop,
        EGlobalMemoryDataOperation,
        TailNum,
        Block2CTileMap,
        EpilogueArgument>(p_as_grid,
                          p_bs_grid,
                          p_ds_grid,
                          p_e_grid,
                          p_shared,
                          problem,
                          default_tile_map,
                          a_element_op,
                          b_element_op,
                          cde_element_op,
                          epilogue_args);
}
// Wrapper function to have __global__ function in common
// between gemm_universal, b_scale, ab_scale, etc.
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum,
typename EpilogueArgument>
typename Block2CTileMap,
typename EpilogueArgument,
int BlockMapMBlockIndex = 0,
int BlockMapNBlockIndex = 1>
__device__ static void Run(void* p_shared,
const SplitKBatchOffset& splitk_batch_offset,
Argument& karg,
const Block2CTileMap& block_2_ctile_map,
EpilogueArgument& epilogue_args)
{
// shift A matrices pointer for splitk
@@ -659,17 +698,47 @@ struct GridwiseGemm_wmma_cshuffle_v3
splitk_batch_offset.b_k_split_offset[i];
});
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_as_grid_splitk,
p_bs_grid_splitk,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
epilogue_args);
Run<HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum,
Block2CTileMap,
EpilogueArgument,
BlockMapMBlockIndex,
BlockMapNBlockIndex>(p_as_grid_splitk,
p_bs_grid_splitk,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg,
block_2_ctile_map,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
epilogue_args);
}
// Wrapper that lets gemm_universal, b_scale, ab_scale, etc. share a single
// __global__ entry point.
// This overload takes no block-to-C-tile map: it creates one from karg via
// DefaultBlock2CTileMap and delegates to the wrapper overload that receives
// the map explicitly.
template <bool HasMainKBlockLoop,
          InMemoryDataOperationEnum EGlobalMemoryDataOperation,
          TailNumber TailNum,
          typename EpilogueArgument>
__device__ static void Run(void* p_shared,
                           const SplitKBatchOffset& splitk_batch_offset,
                           Argument& karg,
                           EpilogueArgument& epilogue_args)
{
    // Tile map derived from the kernel argument only.
    const auto default_tile_map = DefaultBlock2CTileMap(karg);

    Run<HasMainKBlockLoop,
        EGlobalMemoryDataOperation,
        TailNum,
        Block2CTileMap,
        EpilogueArgument>(
        p_shared, splitk_batch_offset, karg, default_tile_map, epilogue_args);
}
// Builds the Block2CTileMap used when a caller does not supply its own map.
// The map is constructed from the problem's M and N members plus the literal 4.
// NOTE(review): the meaning of the third constructor argument (4) is not
// visible here -- confirm against the Block2CTileMap definition.
__device__ static auto DefaultBlock2CTileMap(const Problem& problem)
{
    const auto& m_extent = problem.M;
    const auto& n_extent = problem.N;

    return Block2CTileMap{m_extent, n_extent, 4};
}
};

View File

@@ -729,6 +729,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Arg K value too low for combination of AK1/BK1/KBatch. AK1: "
<< AK1Number << ", BK1: " << BK1Number << ", KBatch: " << karg.KBatch
<< ", K: " << karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
}
return false;
}
}