mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
Merge pull request #5 from ROCmSoftwarePlatform/mi300
Add support for gfx940 targets.
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx940")
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
|
||||
endif()
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
|
||||
|
||||
if(NOT GPU_TARGETS MATCHES "gfx940")
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
|
||||
endif()
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
// check GPU target
|
||||
#ifdef __HIP_DEVICE_COMPILE__
|
||||
#if !(defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__))
|
||||
defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx940__))
|
||||
#error Not supported target
|
||||
#endif
|
||||
#endif
|
||||
@@ -39,7 +39,7 @@
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
|
||||
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) // for GPU code
|
||||
defined(__gfx90a__) || defined(__gfx940__) // for GPU code
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
|
||||
#elif defined(__gfx1030__) // for GPU code
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
|
||||
@@ -51,8 +51,8 @@
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
|
||||
#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
|
||||
#define CK_USE_AMD_V_MAC_F32
|
||||
#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx1030__) // for GPU code
|
||||
#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) || \
|
||||
defined(__gfx940__) // for GPU code
|
||||
#define CK_USE_AMD_V_FMAC_F32
|
||||
#define CK_USE_AMD_V_DOT2_F32_F16
|
||||
#define CK_USE_AMD_V_DOT4_I32_I8
|
||||
@@ -61,14 +61,18 @@
|
||||
// MFMA instruction
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_USE_AMD_MFMA
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) // for GPU code
|
||||
#define CK_USE_AMD_MFMA
|
||||
#endif
|
||||
|
||||
#if defined(__gfx90a__)
|
||||
#if(defined(__gfx90a__) || defined(__gfx940__))
|
||||
#define CK_USE_AMD_MFMA_BF16_1K_OP
|
||||
#endif
|
||||
|
||||
#if defined(__gfx940__)
|
||||
#define CK_USE_AMD_MFMA_GFX940
|
||||
#endif
|
||||
|
||||
// WMMA instruction
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_USE_AMD_WMMA
|
||||
@@ -88,13 +92,13 @@
|
||||
// buffer atomic add: floating point
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) // for GPU code
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#else // for GPU code
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
|
||||
#endif
|
||||
|
||||
#if defined(__gfx90a__) // for GPU code
|
||||
#if(defined(__gfx90a__) || defined(__gfx940__)) // for GPU code
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
|
||||
#else
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
|
||||
|
||||
@@ -47,7 +47,8 @@ __global__ void
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
@@ -416,7 +417,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -43,7 +43,8 @@ __global__ void
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const CElementwiseOperation c_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
@@ -678,7 +679,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -56,7 +56,8 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t num_blocks_per_batch =
|
||||
@@ -938,7 +939,8 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -56,7 +56,8 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t num_blocks_per_batch =
|
||||
@@ -839,7 +840,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -74,7 +74,8 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -60,7 +60,8 @@ __global__ void
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -588,7 +589,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -83,7 +83,8 @@ __global__ void
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
@@ -579,7 +580,8 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -68,7 +68,8 @@ __global__ void
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -790,7 +791,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -59,7 +59,8 @@ __global__ void
|
||||
const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -67,7 +67,8 @@ __global__ void
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -714,7 +715,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
arg.Print();
|
||||
#endif
|
||||
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -62,7 +62,8 @@ __global__ void
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -612,7 +613,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -75,7 +75,8 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -52,7 +52,8 @@ __global__ void
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
@@ -581,7 +582,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -55,7 +55,8 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / num_batches);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -51,7 +51,8 @@ __global__ void
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
@@ -456,7 +457,8 @@ struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -63,7 +63,8 @@ __global__ void
|
||||
const Block2ETileMap block_2_etile_map,
|
||||
index_t NRaw)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
|
||||
@@ -854,7 +855,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -60,7 +60,8 @@ __global__ void
|
||||
const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
@@ -554,7 +555,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -51,7 +51,8 @@ __global__ void
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
@@ -490,7 +491,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -432,7 +432,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if(ck::get_device_name() == "gfx90a")
|
||||
else if(ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940")
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
|
||||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
|
||||
|
||||
@@ -574,7 +574,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -648,7 +648,8 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -37,7 +37,8 @@ __global__ void
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
@@ -703,7 +704,8 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -130,7 +130,8 @@ __global__ void
|
||||
const Block2ETileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
|
||||
@@ -78,7 +78,8 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -155,7 +155,8 @@ __global__ void
|
||||
const Block2ETileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
@@ -810,7 +811,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if(get_device_name() == "gfx90a")
|
||||
else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940")
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
|
||||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
|
||||
|
||||
@@ -135,7 +135,8 @@ __global__ void
|
||||
const Block2ETileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -684,7 +685,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if(get_device_name() == "gfx90a")
|
||||
else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940")
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
|
||||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
|
||||
|
||||
@@ -38,7 +38,8 @@ __global__ void
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation c_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
|
||||
@@ -66,7 +66,8 @@ __global__ void
|
||||
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
|
||||
@@ -54,7 +54,8 @@ __global__ void
|
||||
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
|
||||
@@ -44,7 +44,8 @@ __global__ void
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
|
||||
@@ -57,7 +57,8 @@ __global__ void
|
||||
const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
// TODO ANT: separate into MMA + Epilogue
|
||||
|
||||
@@ -165,7 +165,8 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op,
|
||||
const CBlockClusterAdaptor c_block_cluster_adaptor)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
|
||||
@@ -44,7 +44,8 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
|
||||
|
||||
@@ -43,7 +43,8 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
|
||||
@@ -42,7 +42,8 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op,
|
||||
const CBlockClusterAdaptor c_block_cluster_adaptor)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
|
||||
@@ -44,7 +44,8 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op,
|
||||
const CBlockClusterAdaptor c_block_cluster_adaptor)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
|
||||
@@ -46,7 +46,8 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainK0BlockLoop>(
|
||||
|
||||
@@ -49,7 +49,8 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(
|
||||
|
||||
@@ -53,7 +53,8 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(
|
||||
|
||||
@@ -27,6 +27,8 @@ enum struct MfmaInstr
|
||||
mfma_f32_16x16x8bf16,
|
||||
mfma_i32_32x32x8i8,
|
||||
mfma_i32_16x16x16i8,
|
||||
mfma_i32_32x32x16i8,
|
||||
mfma_i32_16x16x32i8,
|
||||
mfma_f64_16x16x4f64
|
||||
};
|
||||
|
||||
@@ -386,6 +388,50 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_i32_32x32x16i8>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
static constexpr index_t num_regs_per_blk = 16;
|
||||
static constexpr index_t num_threads_per_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 2;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 32;
|
||||
static constexpr index_t n_per_blk = 32;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_i32_32x32x16i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_i32_16x16x32i8>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_regs_per_blk = 4;
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 4;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_i32_16x16x32i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
|
||||
{
|
||||
@@ -524,17 +570,29 @@ struct MfmaSelector
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(CK_USE_AMD_MFMA_GFX940)
|
||||
template <>
|
||||
static constexpr auto GetMfma<int8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_32x32x16i8;
|
||||
}
|
||||
template <>
|
||||
static constexpr auto GetMfma<int8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_16x16x32i8;
|
||||
}
|
||||
#else
|
||||
template <>
|
||||
static constexpr auto GetMfma<int8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_32x32x8i8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<int8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_16x16x16i8;
|
||||
}
|
||||
#endif
|
||||
|
||||
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
|
||||
|
||||
|
||||
@@ -297,6 +297,44 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_i32_32x32x16i8;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_i32_32x32x16i8<32, 32>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<int32x16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<int32x16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_i32_16x16x32i8;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_i32_16x16x32i8<16, 16>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_i32_16x16x32i8(bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<int32x4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f64_16x16x4f64;
|
||||
|
||||
@@ -306,7 +344,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#ifdef __gfx90a__
|
||||
#if defined(__gfx90a__) || defined(__gfx940__)
|
||||
reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
|
||||
reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
|
||||
#else
|
||||
|
||||
@@ -898,6 +898,8 @@ struct vector_type<T, 256>
|
||||
}
|
||||
};
|
||||
|
||||
using int64_t = long;
|
||||
|
||||
// fp64
|
||||
using double2_t = typename vector_type<double, 2>::type;
|
||||
using double4_t = typename vector_type<double, 4>::type;
|
||||
|
||||
@@ -10,10 +10,9 @@ cmake
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_CXX_FLAGS="-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D BUILD_DEV=OFF \
|
||||
-D GPU_TARGETS="gfx90a" \
|
||||
-D BUILD_DEV=ON \
|
||||
-D GPU_TARGETS="gfx908;gfx90a;gfx940" \
|
||||
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
|
||||
-D USE_BITINT_EXTENSION_INT4=OFF \
|
||||
${MY_PROJECT_SOURCE}
|
||||
|
||||
#-D AMDGPU_TARGETS=gfx90a;gfx908
|
||||
|
||||
@@ -11,9 +11,8 @@ cmake
|
||||
-D CMAKE_CXX_FLAGS="-O3" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D BUILD_DEV=OFF \
|
||||
-D GPU_TARGETS="gfx908;gfx90a" \
|
||||
-D GPU_TARGETS="gfx908;gfx90a;gfx940" \
|
||||
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
|
||||
-D USE_BITINT_EXTENSION_INT4=OFF \
|
||||
${MY_PROJECT_SOURCE}
|
||||
|
||||
#-D AMDGPU_TARGETS=gfx90a;gfx908
|
||||
|
||||
Reference in New Issue
Block a user