mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add s_nops after v_dot to avoid hazard (#808)
* Add s_nops after v_dot to avoid hazard
* Fix builtin for inner_produxt fp16
* Skip inline version to builtin
* Add comments regarding isa
* Fix comment regarding s_nop
[ROCm/composable_kernel commit: 7761e5232c]
This commit is contained in:
@@ -118,8 +118,12 @@
|
||||
// inline asm
|
||||
#define CK_USE_AMD_INLINE_ASM 1
|
||||
|
||||
// inner product (DLOP)
|
||||
#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1
|
||||
// inner product (V_MAC/V_FMAC)
|
||||
#define CK_USE_AMD_V_MAC_INLINE_ASM 1
|
||||
|
||||
// V_DOT inline instructions, less efficient since they require adding
|
||||
// `s_nop`s to avoid hazard
|
||||
#define CK_USE_AMD_V_DOT_INLINE_ASM 0
|
||||
|
||||
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
|
||||
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
|
||||
|
||||
@@ -70,10 +70,9 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
// TODO: Enable for gfx90a after complier fix
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \
|
||||
defined(__gfx1102__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
|
||||
defined(__gfx1101__) || defined(__gfx1102__))
|
||||
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -650,10 +649,10 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// TODO: Enable for gfx90a after complier fix
|
||||
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" ||
|
||||
ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx940" ||
|
||||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102")
|
||||
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx1030" ||
|
||||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" ||
|
||||
ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")
|
||||
{
|
||||
bool pass = true;
|
||||
pass = pass && arg.K_ % K1 == 0;
|
||||
|
||||
@@ -13,13 +13,13 @@ __device__ void inner_product(const TA& a, const TB& b, TC& c);
|
||||
template <>
|
||||
__device__ void inner_product<float, float, float>(const float& a, const float& b, float& c)
|
||||
{
|
||||
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32)
|
||||
#if CK_USE_AMD_V_MAC_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32)
|
||||
asm volatile("\n \
|
||||
v_mac_f32 %0, %1, %2 \n \
|
||||
"
|
||||
: "=v"(c)
|
||||
: "v"(a), "v"(b), "0"(c));
|
||||
#elif CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32)
|
||||
#elif CK_USE_AMD_V_MAC_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32)
|
||||
asm volatile("\n \
|
||||
v_fmac_f32 %0, %1, %2 \n \
|
||||
"
|
||||
@@ -76,14 +76,18 @@ template <>
|
||||
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
|
||||
{
|
||||
#if defined(CK_USE_AMD_V_DOT2_F32_F16)
|
||||
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
|
||||
#if CK_USE_AMD_V_DOT_INLINE_ASM
|
||||
// Use 3 x s_nop to avoid hazard (mi200 cdna2 isa page 47
|
||||
// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf
|
||||
// ) s_nop with parameter 2 is equal to 3 x s_nop
|
||||
asm volatile("\n \
|
||||
v_dot2_f32_f16 %0, %1, %2, %0\n \
|
||||
s_nop 2 \n \
|
||||
"
|
||||
: "=v"(c)
|
||||
: "v"(a), "v"(b), "0"(c));
|
||||
#else
|
||||
c = __builtin_amdgcn_sdot2(a, b, c, false);
|
||||
c = __builtin_amdgcn_fdot2(a, b, c, false);
|
||||
#endif
|
||||
#else
|
||||
const vector_type<half_t, 2> a_vector{a};
|
||||
@@ -163,9 +167,13 @@ __device__ void
|
||||
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
|
||||
{
|
||||
#if defined(CK_USE_AMD_V_DOT4_I32_I8)
|
||||
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
|
||||
#if CK_USE_AMD_V_DOT_INLINE_ASM
|
||||
// Use 3 x s_nop to avoid hazard (mi200 cdna2 isa page 47
|
||||
// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf
|
||||
// ) s_nop with parameter 2 is equal to 3 x s_nop
|
||||
asm volatile("\n \
|
||||
v_dot4_i32_i8 %0, %1, %2, %0\n \
|
||||
s_nop 2 \n \
|
||||
"
|
||||
: "=v"(c)
|
||||
: "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
|
||||
|
||||
@@ -3,13 +3,6 @@
|
||||
## GPU visibility
|
||||
export HIP_VISIBLE_DEVICES=0
|
||||
DRIVER="../build/bin/ckProfiler"
|
||||
OP=$1
|
||||
DATATYPE=$2
|
||||
LAYOUT=$3
|
||||
VERIFY=$4
|
||||
INIT=$5
|
||||
LOG=$6
|
||||
TIME=$7
|
||||
|
||||
OP=$1
|
||||
DATATYPE=$2
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
# TODO: Enable for gfx90a after complier fix
|
||||
if(DL_KERNELS)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx90a")
|
||||
add_gtest_executable(test_batched_gemm_multi_d test_batched_gemm_multi_d.cpp)
|
||||
target_link_libraries(test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance)
|
||||
endif()
|
||||
add_gtest_executable(test_batched_gemm_multi_d test_batched_gemm_multi_d.cpp)
|
||||
target_link_libraries(test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance)
|
||||
endif()
|
||||
|
||||
Reference in New Issue
Block a user