Code clean-up (#1285)

* code clean-up

* remove the profiling output samples

[ROCm/composable_kernel commit: 566b6480a2]
This commit is contained in:
Illia Silin
2024-05-10 09:41:39 -07:00
committed by GitHub
parent 866c3302ad
commit a90f0099fc
38 changed files with 57 additions and 259 deletions

View File

@@ -236,7 +236,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#ifndef CK_WORKAROUND_DENORM_FIX
#define CK_WORKAROUND_DENORM_FIX 0
#else
// enable only on MI200
// enable only for gfx90a
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#endif // CK_WORKAROUND_DENORM_FIX

View File

@@ -65,20 +65,20 @@ inline bool is_lds_direct_load_supported()
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
}
inline bool is_navi1_supported()
inline bool is_gfx101_supported()
{
return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
ck::get_device_name() == "gfx1012";
}
inline bool is_navi2_supported()
inline bool is_gfx103_supported()
{
return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" ||
ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" ||
ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
}
inline bool is_navi3_supported()
inline bool is_gfx11_supported()
{
return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";

View File

@@ -829,7 +829,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{

View File

@@ -648,7 +648,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported())
ck::is_gfx103_supported() || ck::is_gfx11_supported())
{
bool pass = true;
pass = pass && arg.K_ % K1 == 0;

View File

@@ -858,7 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
@@ -1435,7 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
#if 0
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{

View File

@@ -1392,8 +1392,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
ck::is_navi3_supported()))
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_gfx11_supported()))
{
return false;
}

View File

@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>))

View File

@@ -535,8 +535,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
}
}
if(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
ck::is_navi3_supported())
if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_gfx11_supported())
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);

View File

@@ -168,7 +168,7 @@ struct DeviceGemmDpp : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& karg)
{
if(ck::is_navi2_supported() || ck::is_navi3_supported())
if(ck::is_gfx103_supported() || ck::is_gfx11_supported())
{
return GridwiseGemm::CheckValidity(karg);
}

View File

@@ -552,7 +552,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported())
ck::is_gfx103_supported() || ck::is_gfx11_supported())
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);

View File

@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{

View File

@@ -443,7 +443,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>))

View File

@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{

View File

@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{

View File

@@ -666,7 +666,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported()))
ck::is_gfx103_supported() || ck::is_gfx11_supported()))
{
return false;
}

View File

@@ -601,8 +601,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
namespace ctc = tensor_layout::convolution;
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
ck::is_navi3_supported()))
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_gfx11_supported()))
{
return false;
}

View File

@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
namespace ctc = tensor_layout::convolution;
// check device
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{

View File

@@ -673,7 +673,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
}
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported())
ck::is_gfx103_supported() || ck::is_gfx11_supported())
{
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{

View File

@@ -596,7 +596,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
@@ -958,7 +958,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
#if 0
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{

View File

@@ -594,7 +594,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
@@ -950,7 +950,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
#if 0
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{

View File

@@ -260,7 +260,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
};
// Grouped Rows of column-vectors WGP mapping
// Optimized for MI300-like multipe-die chip
// Optimized for gfx94x-like multipe-die chip
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt

View File

@@ -95,7 +95,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
// * Fixed on gfx11, Will be wave mode dependent for future architectures
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
// * num_acc_vgprs_per_wave alone M direction

View File

@@ -4,7 +4,7 @@
#pragma once
namespace ck {
// Define the common macro for MI300 models
// Define the common macro for gfx94x models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif

View File

@@ -8,7 +8,7 @@
#include "ck/utility/random_gen.hpp"
namespace ck {
// Define the common macro for MI300 models
// Define the common macro for gfx94x models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif