mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Add support for more Navi2x and Navi3x models. (#1152)
* add support for navi2x and navi3x models * fix syntax * use common macro for different mi300 architectures
This commit is contained in:
@@ -770,8 +770,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102")
|
||||
if(ck::is_navi3_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -57,7 +57,7 @@ __global__ void
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t num_blocks_per_batch =
|
||||
|
||||
@@ -75,7 +75,7 @@ __global__ void
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -61,7 +61,7 @@ __global__ void
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
|
||||
@@ -84,7 +84,7 @@ __global__ void
|
||||
{
|
||||
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -70,9 +70,8 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
|
||||
defined(__gfx1101__) || defined(__gfx1102__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
|
||||
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -648,11 +647,8 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// TODO: Enable for gfx90a after complier fix
|
||||
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx1030" ||
|
||||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" ||
|
||||
ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_navi2_supported() || ck::is_navi3_supported())
|
||||
{
|
||||
bool pass = true;
|
||||
pass = pass && arg.K_ % K1 == 0;
|
||||
|
||||
@@ -69,7 +69,7 @@ __global__ void
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
|
||||
@@ -60,7 +60,7 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -68,7 +68,7 @@ __global__ void
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
|
||||
@@ -63,7 +63,7 @@ __global__ void
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
|
||||
@@ -53,7 +53,7 @@ __global__ void
|
||||
kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -56,7 +56,7 @@ __global__ void
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid,
|
||||
|
||||
@@ -54,7 +54,7 @@ __global__ void
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
|
||||
@@ -56,7 +56,7 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / num_batches);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -1393,9 +1393,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
|
||||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102"))
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
|
||||
ck::is_navi3_supported()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -536,9 +536,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
}
|
||||
}
|
||||
|
||||
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
|
||||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102")
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
|
||||
ck::is_navi3_supported())
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
|
||||
|
||||
@@ -168,8 +168,7 @@ struct DeviceGemmDpp : public DeviceGemm<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& karg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1100" ||
|
||||
ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")
|
||||
if(ck::is_navi2_supported() || ck::is_navi3_supported())
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(karg);
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ __global__ void
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid,
|
||||
|
||||
@@ -50,9 +50,8 @@ __global__ void
|
||||
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
|
||||
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
|
||||
@@ -552,11 +551,8 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" ||
|
||||
ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx1030" ||
|
||||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" ||
|
||||
ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102" ||
|
||||
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_navi2_supported() || ck::is_navi3_supported())
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
|
||||
|
||||
@@ -64,7 +64,7 @@ __global__ void
|
||||
index_t NRaw)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
|
||||
|
||||
@@ -61,7 +61,7 @@ __global__ void
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
|
||||
@@ -484,8 +484,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102")
|
||||
if(ck::is_navi3_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -53,7 +53,7 @@ __global__ void
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
|
||||
@@ -411,8 +411,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102")
|
||||
if(ck::is_navi3_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -184,8 +184,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if(ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" ||
|
||||
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")
|
||||
else if(ck::is_lds_direct_load_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
|
||||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
|
||||
|
||||
@@ -243,9 +243,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& karg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
|
||||
ck::get_device_name() == "gfx942"))
|
||||
if(!(ck::is_xdl_supported()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ __global__ void
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
|
||||
@@ -38,7 +38,7 @@ __global__ void
|
||||
const CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
|
||||
@@ -627,8 +627,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// check device
|
||||
if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102")
|
||||
if(ck::is_navi3_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -87,7 +87,7 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
|
||||
@@ -48,9 +48,8 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \
|
||||
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -698,8 +698,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// check device
|
||||
if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" ||
|
||||
get_device_name() == "gfx1102")
|
||||
if(ck::is_navi3_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -55,7 +55,7 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -90,9 +90,8 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \
|
||||
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -666,11 +665,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
|
||||
ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx908" ||
|
||||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" ||
|
||||
ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102" ||
|
||||
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"))
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_navi2_supported() || ck::is_navi3_supported()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -106,8 +106,8 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
|
||||
defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx11__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -601,9 +601,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
|
||||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102"))
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
|
||||
ck::is_navi3_supported()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -817,8 +817,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940" ||
|
||||
get_device_name() == "gfx941" || get_device_name() == "gfx942")
|
||||
else if(ck::is_lds_direct_load_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
|
||||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
|
||||
|
||||
@@ -156,7 +156,7 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
@@ -813,8 +813,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940" ||
|
||||
get_device_name() == "gfx941" || get_device_name() == "gfx942")
|
||||
else if(ck::is_lds_direct_load_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
|
||||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
|
||||
|
||||
@@ -531,8 +531,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
// check device
|
||||
if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102")
|
||||
if(ck::is_navi3_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -39,9 +39,8 @@ __global__ void
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \
|
||||
defined(__gfx1102__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
@@ -668,26 +667,24 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
const std::string device_name = ck::get_device_name();
|
||||
|
||||
// TODO add newer Navi arch
|
||||
if(device_name != "gfx906" and device_name != "gfx908" and device_name != "gfx90a" and
|
||||
device_name != "gfx1030")
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_navi2_supported() || ck::is_navi3_supported())
|
||||
{
|
||||
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
|
||||
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
|
||||
arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
|
||||
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
|
||||
arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
@@ -44,7 +44,7 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
|
||||
@@ -39,7 +39,7 @@ __global__ void
|
||||
const CDEElementwiseOperation c_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
|
||||
@@ -50,7 +50,7 @@ __global__ void
|
||||
const CDEElementwiseOperation c_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
|
||||
@@ -35,7 +35,7 @@ __global__ void
|
||||
const index_t group_count)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
|
||||
__shared__ uint8_t p_shared[shared_size];
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ __global__ void
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
defined(__gfx94__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t num_blocks_per_batch =
|
||||
|
||||
Reference in New Issue
Block a user