mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 21:39:15 +00:00
Merge commit '46f1d740f03d11bc2a78fce60a95cd0933b9dd4d' into develop
This commit is contained in:
@@ -470,9 +470,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
DsGridPointer p_ds_grid;
|
||||
EDataType* p_e_grid;
|
||||
|
||||
const AElementwiseOperation a_element_op;
|
||||
const BElementwiseOperation b_element_op;
|
||||
const CDEElementwiseOperation cde_element_op;
|
||||
AElementwiseOperation a_element_op;
|
||||
BElementwiseOperation b_element_op;
|
||||
CDEElementwiseOperation cde_element_op;
|
||||
|
||||
// TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
|
||||
bool is_reduce;
|
||||
@@ -555,13 +555,17 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
typename Block2CTileMap,
|
||||
typename EpilogueArgument,
|
||||
int BlockMapMBlockIndex = 0,
|
||||
int BlockMapNBlockIndex = 1>
|
||||
__device__ static void Run(AsGridPointer& p_as_grid,
|
||||
BsGridPointer& p_bs_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
const Problem& problem,
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
@@ -582,9 +586,6 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
@@ -596,8 +597,10 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
return;
|
||||
}
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
const index_t block_m_id =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapMBlockIndex>{}]);
|
||||
const index_t block_n_id =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);
|
||||
|
||||
// BScale struct (Empty)
|
||||
using BScale = typename BlockwiseGemmPipe::Empty;
|
||||
@@ -632,15 +635,51 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(AsGridPointer& p_as_grid,
|
||||
BsGridPointer& p_bs_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
Run<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
Block2CTileMap,
|
||||
EpilogueArgument>(p_as_grid,
|
||||
p_bs_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
problem,
|
||||
DefaultBlock2CTileMap(problem),
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
// Wrapper function to have __global__ function in common
|
||||
// between gemm_universal, b_scale, ab_scale, etc.
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
typename Block2CTileMap,
|
||||
typename EpilogueArgument,
|
||||
int BlockMapMBlockIndex = 0,
|
||||
int BlockMapNBlockIndex = 1>
|
||||
__device__ static void Run(void* p_shared,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
Argument& karg,
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
// shift A matrices pointer for splitk
|
||||
@@ -659,17 +698,47 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
splitk_batch_offset.b_k_split_offset[i];
|
||||
});
|
||||
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_as_grid_splitk,
|
||||
p_bs_grid_splitk,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
epilogue_args);
|
||||
Run<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
Block2CTileMap,
|
||||
EpilogueArgument,
|
||||
BlockMapMBlockIndex,
|
||||
BlockMapNBlockIndex>(p_as_grid_splitk,
|
||||
p_bs_grid_splitk,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
block_2_ctile_map,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
// Wrapper function to have __global__ function in common
|
||||
// between gemm_universal, b_scale, ab_scale, etc.
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(void* p_shared,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
Run<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
Block2CTileMap,
|
||||
EpilogueArgument>(
|
||||
p_shared, splitk_batch_offset, karg, DefaultBlock2CTileMap(karg), epilogue_args);
|
||||
}
|
||||
|
||||
__device__ static auto DefaultBlock2CTileMap(const Problem& problem)
|
||||
{
|
||||
return Block2CTileMap{problem.M, problem.N, 4};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -729,6 +729,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
|
||||
if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K value too low for combination of AK1/BK1/KBatch. AK1: "
|
||||
<< AK1Number << ", BK1: " << BK1Number << ", KBatch: " << karg.KBatch
|
||||
<< ", K: " << karg.K << " " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user