mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Merging the gfx12 code into public repo. (#1362)
This commit is contained in:
@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
|
||||
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
|
||||
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
|
||||
static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
|
||||
static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
|
||||
|
||||
static constexpr auto AEnableLds_auto =
|
||||
(NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true;
|
||||
static constexpr auto BEnableLds_auto =
|
||||
(MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? false : true;
|
||||
|
||||
// If true, LDS is used unconditionally
|
||||
static constexpr auto AEnableLds_manu = false;
|
||||
@@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
@@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!(arg.a_kz_stride_ == 1 &&
|
||||
arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
|
||||
if(!(arg.a_kz_stride_ == 1))
|
||||
{
|
||||
printf("DeviceOp: Vector Access A-k check failure\n");
|
||||
return false;
|
||||
index_t LastK =
|
||||
AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6);
|
||||
if(LastK % ABlockTransferSrcScalarPerVector == 0)
|
||||
{
|
||||
printf("DeviceOp: Vector Access A-k check failure\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -70,8 +70,9 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__))
|
||||
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported())
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
bool pass = true;
|
||||
pass = pass && arg.K_ % K1 == 0;
|
||||
|
||||
@@ -56,7 +56,7 @@ __global__ void
|
||||
bool input_permute,
|
||||
bool output_permute)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
// clang-format off
|
||||
// ***************************************************
|
||||
@@ -159,6 +159,7 @@ __global__ void
|
||||
ignore = O;
|
||||
ignore = G0;
|
||||
ignore = G1;
|
||||
ignore = alpha;
|
||||
ignore = input_permute;
|
||||
ignore = output_permute;
|
||||
#endif // end of if (defined(__gfx11__) || defined(__gfx12__))
|
||||
@@ -187,7 +188,7 @@ __global__ void
|
||||
index_t head_size,
|
||||
float alpha)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
// clang-format off
|
||||
// ***************************************************
|
||||
@@ -321,7 +322,7 @@ __global__ void
|
||||
index_t head_size,
|
||||
float alpha)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
// clang-format off
|
||||
// ***************************************************
|
||||
@@ -858,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const RawArg& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -592,9 +592,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" &&
|
||||
ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" &&
|
||||
std::is_same<ADataType, double>::value)
|
||||
if(!ck::is_lds_direct_load_supported() && std::is_same<ADataType, double>::value)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1393,7 +1393,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
|
||||
{
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
|
||||
ck::is_gfx11_supported()))
|
||||
ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
|
||||
is_same_v<AccDataType, int32_t>))
|
||||
|
||||
@@ -536,7 +536,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
}
|
||||
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
|
||||
ck::is_gfx11_supported())
|
||||
ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
|
||||
|
||||
@@ -50,8 +50,9 @@ __global__ void
|
||||
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__))
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
|
||||
@@ -552,7 +553,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported())
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
|
||||
|
||||
@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
// K1 = Max Vector Access Pixels
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
|
||||
static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
|
||||
|
||||
static constexpr auto AEnableLds_auto =
|
||||
(NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
|
||||
static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) &&
|
||||
is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
? false
|
||||
: true;
|
||||
static constexpr auto BEnableLds_auto =
|
||||
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
|
||||
(MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) &&
|
||||
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
? false
|
||||
: true;
|
||||
|
||||
// If true, LDS is used unconditionally
|
||||
static constexpr auto AEnableLds_manu = false;
|
||||
@@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
|
||||
is_same_v<AccDataType, int32_t>))
|
||||
|
||||
@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// check device
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -48,8 +48,9 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// check device
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -90,8 +90,9 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -667,7 +668,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported()))
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx11__))
|
||||
defined(__gfx11__) || defined(__gfx12__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -603,7 +603,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
|
||||
ck::is_gfx11_supported()))
|
||||
ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -582,7 +582,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
// check device
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -39,8 +39,9 @@ __global__ void
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \
|
||||
defined(__gfx12__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
@@ -673,7 +674,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
}
|
||||
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported())
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
|
||||
{
|
||||
|
||||
@@ -61,7 +61,7 @@ __global__ void
|
||||
bool input_permute,
|
||||
bool output_permute)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
// clang-format off
|
||||
// ***************************************************
|
||||
@@ -166,6 +166,7 @@ __global__ void
|
||||
ignore = O;
|
||||
ignore = G0;
|
||||
ignore = G1;
|
||||
ignore = alpha;
|
||||
ignore = input_permute;
|
||||
ignore = output_permute;
|
||||
#endif // end of if (defined(__gfx11__) || defined(__gfx12__))
|
||||
@@ -596,7 +597,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
|
||||
|
||||
static bool IsSupportedArgument(const RawArg& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -60,7 +60,7 @@ __global__ void
|
||||
bool input_permute,
|
||||
bool output_permute)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
// clang-format off
|
||||
// ***************************************************
|
||||
@@ -165,6 +165,7 @@ __global__ void
|
||||
ignore = O;
|
||||
ignore = G0;
|
||||
ignore = G1;
|
||||
ignore = alpha;
|
||||
ignore = input_permute;
|
||||
ignore = output_permute;
|
||||
#endif // end of if (defined(__gfx11__) || defined(__gfx12__))
|
||||
@@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
|
||||
|
||||
static bool IsSupportedArgument(const RawArg& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user