Fix issue with multiple targets and remove smfmac tests from unsupported test targets (#1372)

[ROCm/composable_kernel commit: 959073842c]
This commit is contained in:
Jun Liu
2024-07-03 23:34:38 -07:00
committed by GitHub
parent 35bbee7130
commit fa73739812
24 changed files with 243 additions and 50 deletions

4
Jenkinsfile vendored
View File

@@ -886,10 +886,10 @@ pipeline {
}
agent{ label rocmnode("gfx90a") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1100;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx908;gfx90a" \
-DGPU_TARGETS="gfx1100;gfx90a" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}

View File

@@ -7,19 +7,23 @@
#include <initializer_list>
#include <vector>
#include "ck/utility/common_header.hpp"
// __gfx9__ defined in the above header via ck.hpp
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
struct SimpleDeviceMem
{
@@ -204,6 +208,14 @@ void PerformGemm(const ck::index_t M,
int main(int argc, char* argv[])
{
bool is_supported = ck::is_xdl_supported();
if(!is_supported)
{
std::cout << "WARNING: xdl example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
using DataType = ck::half_t;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}),
@@ -213,3 +225,4 @@ int main(int argc, char* argv[])
3840, 4096, 4096, tile_shape, thread_layout);
return 0;
}
#endif

View File

@@ -7,18 +7,21 @@
#include <initializer_list>
#include <vector>
#include "ck/library/utility/host_tensor.hpp"
#include "ck/utility/common_header.hpp"
// __gfx9__ defined in the above header via ck.hpp
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
struct SimpleDeviceMem
{
@@ -296,6 +299,14 @@ void PerformGemm(const ck::index_t M,
int main(int argc, char* argv[])
{
bool is_supported = ck::is_xdl_supported();
if(!is_supported)
{
std::cout << "WARNING: xdl example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
using DataType = ck::half_t;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
@@ -305,3 +316,4 @@ int main(int argc, char* argv[])
3840, 4096, 4096, tile_shape, thread_layout);
return 0;
}
#endif

View File

@@ -17,6 +17,7 @@
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/host_utility/device_prop.hpp"
struct AlphaBetaAdd
{
@@ -175,6 +176,14 @@ int main(int argc, char* argv[])
exit(0);
}
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;

View File

@@ -17,6 +17,7 @@
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/host_utility/device_prop.hpp"
struct AlphaBetaAdd
{
@@ -175,6 +176,14 @@ int main(int argc, char* argv[])
exit(0);
}
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;

View File

@@ -2,6 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
// kernel data types
using InKernelDataType = FP16;
@@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv);
}

View File

@@ -2,6 +2,7 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
// kernel data types
using InKernelDataType = I8;
@@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv);
}

View File

@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -163,4 +164,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return run(argc, argv);
}

View File

@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -285,4 +286,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return run(argc, argv);
}

View File

@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -351,4 +352,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
#include "run_cross_attention_wmma.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return run(argc, argv);
}

View File

@@ -28,6 +28,7 @@ Example is GQA-4
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -299,4 +300,14 @@ using ReferenceGemm1Instance =
#include "run_grouped_query_attention_forward_wmma.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return run(argc, argv);
}

View File

@@ -26,6 +26,7 @@ Shazeer, Noam. “Fast Transformer Decoding: One Write-Head Is All You Need.”
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -284,4 +285,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm_
#include "run_multi_query_attention_forward_wmma.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return run(argc, argv);
}

View File

@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -329,4 +330,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
#include "run_self_attention_wmma.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return run(argc, argv);
}

View File

@@ -3,6 +3,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
#include "common.hpp"
#include "ck/host_utility/device_prop.hpp"
using OutDataType = FP16;
using WeiDataType = FP16;
@@ -31,4 +32,14 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat
#include "run_grouped_conv_bwd_data_example.inc"
int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return run_grouped_conv_bwd_data_example(argc, argv);
}

View File

@@ -47,12 +47,12 @@ __global__ void
#endif
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3(
typename GridwiseGemm::Argument karg,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t num_k_per_block)
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
[[maybe_unused]] const index_t num_k_per_block)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
@@ -103,12 +103,12 @@ __global__ void
#endif
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds(
typename GridwiseGemm::Argument karg,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t num_k_per_block)
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
[[maybe_unused]] const index_t num_k_per_block)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))

View File

@@ -69,14 +69,15 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffset compute_ptr_offset_of_groups,
const ComputePtrOffset compute_ptr_offset_of_n,
const index_t groups_count)
kernel_grouped_conv_fwd_xdl_cshuffle_v3(
typename GridwiseGemm::Argument karg,
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups,
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n,
[[maybe_unused]] const index_t groups_count)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group
@@ -132,13 +133,13 @@ __global__ void
#endif
kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds(
typename GridwiseGemm::Argument karg,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffset compute_ptr_offset_of_groups,
const ComputePtrOffset compute_ptr_offset_of_n,
const index_t groups_count)
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups,
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n,
[[maybe_unused]] const index_t groups_count)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group

View File

@@ -16,8 +16,15 @@ struct intrin_smfmac_f32_16x16x32f16<16, 16>
__device__ static void
Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
ignore = reg_idx;
#endif
}
};
@@ -31,8 +38,15 @@ struct intrin_smfmac_f32_16x16x32bf16<16, 16>
__device__ static void
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
ignore = reg_idx;
#endif
}
};
@@ -46,8 +60,15 @@ struct intrin_smfmac_f32_32x32x16f16<32, 32>
__device__ static void
Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
ignore = reg_idx;
#endif
}
};
@@ -61,8 +82,15 @@ struct intrin_smfmac_f32_32x32x16bf16<32, 32>
__device__ static void
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
ignore = reg_idx;
#endif
}
};

View File

@@ -71,6 +71,8 @@ function(add_test_executable TEST_NAME)
list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
elseif(ARGN MATCHES "_wmma")
list(REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
elseif(ARGN MATCHES "_smfmac")
list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a)
endif()
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
add_executable(${TEST_NAME} ${ARGN})
@@ -150,6 +152,8 @@ function(add_gtest_executable TEST_NAME)
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
elseif(ARGN MATCHES "_wmma")
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
elseif(ARGN MATCHES "_smfmac")
list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a)
endif()
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
add_executable(${TEST_NAME} ${ARGN})
@@ -209,7 +213,7 @@ add_subdirectory(wrapper)
if(GPU_TARGETS MATCHES "gfx11")
add_subdirectory(wmma_op)
endif()
if(GPU_TARGETS MATCHES "gfx942")
if(GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2
add_subdirectory(smfmac_op)
endif()
add_subdirectory(position_embedding)

View File

@@ -2,11 +2,11 @@ add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_x
if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
endif()
add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp)
add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance)
target_link_libraries(test_grouped_convnd_bwd_data_interface_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance)
endif()
add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp)
add_gtest_executable(test_grouped_convnd_bwd_data_interface_wmma test_grouped_convnd_bwd_data_interface_wmma.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance)
target_link_libraries(test_grouped_convnd_bwd_data_interface_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance)
endif()

View File

@@ -52,6 +52,14 @@ class TestGroupedConvndBwdData : public ::testing::Test
ck::utils::conv::ConvParam conv_param;
void SetUp() override
{
if(!ck::is_gfx11_supported())
{
GTEST_SKIP();
}
}
template <ck::index_t NDimSpatial>
bool Run()
{

View File

@@ -5,13 +5,13 @@ if(GPU_TARGETS MATCHES "gfx9" OR DL_KERNELS)
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv3d_bwd_weight_instance)
endif()
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp)
add_gtest_executable(test_grouped_convnd_bwd_weight_interface_xdl test_grouped_convnd_bwd_weight_interface_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility)
target_link_libraries(test_grouped_convnd_bwd_weight_interface_xdl PRIVATE utility)
endif()
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp)
add_gtest_executable(test_grouped_convnd_bwd_weight_interface_wmma test_grouped_convnd_bwd_weight_interface_wmma.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility)
target_link_libraries(test_grouped_convnd_bwd_weight_interface_wmma PRIVATE utility)
endif()
add_gtest_executable(test_grouped_conv_bwd_weight_xdl_bilinear test_grouped_conv_bwd_weight_xdl_bilinear.cpp)
if(result EQUAL 0)

View File

@@ -52,6 +52,14 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
ck::utils::conv::ConvParam conv_param;
void SetUp() override
{
if(!ck::is_gfx11_supported())
{
GTEST_SKIP();
}
}
template <ck::index_t SplitK>
bool Run()
{

View File

@@ -1,6 +1,6 @@
if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11")
add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp)
if(GPU_TARGETS MATCHES "gfx11")
if((GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9"))
target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance)
else()
target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance)

View File

@@ -11,6 +11,7 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/utility/amd_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace wmma_op_util {
@@ -373,7 +374,8 @@ struct TestWmma
a, b, c_host, a_element_op, b_element_op, c_element_op);
// Act
bool is_supported = ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device);
bool is_supported = ck::is_gfx11_supported() &&
ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device);
if(is_supported)
{