mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
Add support for more Navi2x and Navi3x models. (#1152)
* add support for navi2x and navi3x models * fix syntax * use common macro for different mi300 architectures
This commit is contained in:
@@ -67,7 +67,7 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
|
||||
@@ -28,8 +28,7 @@ __global__ void
|
||||
#endif
|
||||
kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__) || defined(__gfx1100__) || \
|
||||
defined(__gfx1101__) || defined(__gfx1102__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx103__) || defined(__gfx11__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane(
|
||||
|
||||
@@ -54,8 +54,7 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
|
||||
defined(__gfx1102__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -148,8 +147,7 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2CTileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
|
||||
defined(__gfx1102__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
// printf("entry kernel launch");
|
||||
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -244,8 +242,7 @@ __global__ void
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
|
||||
defined(__gfx1102__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
@@ -274,7 +271,7 @@ __global__ void
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
ignore = block_2_ctile_map;
|
||||
#endif // end of if (defined(__gfx1100__ ))
|
||||
#endif // end of if (defined(__gfx11__ ))
|
||||
}
|
||||
|
||||
template < // DataType Family
|
||||
|
||||
@@ -55,8 +55,7 @@ __global__ void
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx940__) || \
|
||||
defined(__gfx941__) || defined(__gfx942__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
|
||||
@@ -55,7 +55,7 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
|
||||
@@ -49,8 +49,7 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
|
||||
defined(__gfx1102__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
@@ -75,7 +74,7 @@ __global__ void
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = block_2_ctile_map;
|
||||
#endif // end of if (defined(__gfx1100__))
|
||||
#endif // end of if (defined(__gfx11__))
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
|
||||
@@ -25,7 +25,7 @@ __global__ void
|
||||
kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(
|
||||
@@ -50,7 +50,7 @@ __global__ void
|
||||
typename GridwiseGemm::Problem problem)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
|
||||
|
||||
@@ -26,7 +26,7 @@ __global__ void
|
||||
kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
// Pass two lds pointer is the key to tell compiler that ds_read/write
|
||||
// operate on different lds chunk at same time without order dependecy
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
@@ -54,7 +54,7 @@ __global__ void
|
||||
typename GridwiseGemm::Problem problem)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
// TODO ANT: separate into MMA + Epilogue
|
||||
|
||||
@@ -167,7 +167,7 @@ __global__ void
|
||||
const CBlockClusterAdaptor c_block_cluster_adaptor)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
|
||||
@@ -45,7 +45,7 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
|
||||
|
||||
@@ -38,7 +38,7 @@ __global__ void
|
||||
typename GridwiseGemm::Block2CTileMap block_mapping)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
|
||||
|
||||
__shared__ uint8_t p_shared[shared_size];
|
||||
|
||||
@@ -39,7 +39,7 @@ __global__ void
|
||||
const CGridDesc_M_N c_grid_desc_m_n)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
@@ -70,7 +70,7 @@ __global__ void
|
||||
kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const auto a_grid_desc_k0_m_k1 =
|
||||
|
||||
@@ -43,7 +43,7 @@ __global__ void
|
||||
const CBlockClusterAdaptor c_block_cluster_adaptor)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
|
||||
|
||||
__shared__ uint8_t p_shared[shared_size];
|
||||
|
||||
@@ -47,7 +47,7 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainK0BlockLoop>(
|
||||
|
||||
@@ -50,7 +50,7 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(
|
||||
|
||||
@@ -54,7 +54,7 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(
|
||||
|
||||
@@ -35,9 +35,8 @@ __global__ void
|
||||
const Block2ETileMap block_2_tile_map,
|
||||
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
|
||||
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
|
||||
GridwiseTensorRearrangeKernel::Run(in_grid_desc,
|
||||
p_in_global,
|
||||
out_grid_desc,
|
||||
|
||||
Reference in New Issue
Block a user