mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
added ab_elementwise_op support into splitK Gemm (#956)
* add ab_elementwise * fixed ci * fixed a merge issue * fixed pr comments * fixed a conflict * remove 61_example --------- Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -22,13 +22,19 @@ namespace ck {
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename Block2CTileMap>
|
||||
typename Block2CTileMap,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
|
||||
const Block2CTileMap& b2c_map)
|
||||
const Block2CTileMap& b2c_map,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
@@ -37,10 +43,13 @@ __global__ void
|
||||
__shared__ uint8_t p_shared[shared_size];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
|
||||
karg, static_cast<void*>(p_shared), b2c_map);
|
||||
karg, static_cast<void*>(p_shared), b2c_map, a_element_op, b_element_op, c_element_op);
|
||||
#else
|
||||
ignore = karg;
|
||||
ignore = b2c_map;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
@@ -577,7 +586,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
typename Block2CTileMap>
|
||||
__device__ static void Run(const Argument& karg,
|
||||
void* __restrict__ p_shared_block,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
const AElementwiseOperation a_element_op = AElementwiseOperation{},
|
||||
const BElementwiseOperation b_element_op = BElementwiseOperation{},
|
||||
const CElementwiseOperation c_element_op = CElementwiseOperation{})
|
||||
{
|
||||
const FloatA* p_a_grid = karg.p_a_grid;
|
||||
const FloatB* p_b_grid = karg.p_b_grid;
|
||||
@@ -590,9 +602,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
|
||||
const AElementwiseOperation a_element_op = AElementwiseOperation{};
|
||||
const BElementwiseOperation b_element_op = BElementwiseOperation{};
|
||||
const CElementwiseOperation c_element_op = CElementwiseOperation{};
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
|
||||
@@ -761,8 +770,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
ComputeType,
|
||||
ComputeType,
|
||||
ComputeType, // ComputeType A
|
||||
ComputeType, // ComputeType B
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
|
||||
Reference in New Issue
Block a user