Change relu to clamp for grouped conv fwd instances (#2249)

[ROCm/composable_kernel commit: e7906dd644]
This commit is contained in:
Bartłomiej Kocot
2025-05-29 00:51:25 +02:00
committed by GitHub
parent fe30e881d6
commit f1ec9e0be5
34 changed files with 291 additions and 195 deletions

View File

@@ -251,7 +251,7 @@ add_subdirectory(reduce)
add_subdirectory(convnd_fwd)
add_subdirectory(convnd_bwd_data)
add_subdirectory(grouped_convnd_fwd)
add_subdirectory(grouped_convnd_fwd_bias_relu)
add_subdirectory(grouped_convnd_fwd_bias_clamp)
add_subdirectory(grouped_convnd_bwd_weight)
add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax)

View File

@@ -0,0 +1,4 @@
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_grouped_convnd_fwd_bias_clamp test_grouped_convnd_fwd_bias_clamp.cpp)
target_link_libraries(test_grouped_convnd_fwd_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance)
endif()

View File

@@ -7,11 +7,11 @@
#include <vector>
#include <gtest/gtest.h>
#include "profiler/profile_grouped_conv_fwd_bias_relu_impl.hpp"
#include "profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
template <typename Tuple>
class TestGroupedConvndFwd : public ::testing::Test
@@ -32,16 +32,16 @@ class TestGroupedConvndFwd : public ::testing::Test
bool pass = true;
for(auto& param : conv_params)
{
pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_relu_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
DataType,
DataType,
DataType,
DataType,
DataType,
IndexType>(
pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
DataType,
DataType,
DataType,
DataType,
DataType,
IndexType>(
true, // do_verification
1, // init_method: integer value
false, // do_log

View File

@@ -1,4 +0,0 @@
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_grouped_convnd_fwd_bias_relu test_grouped_convnd_fwd_bias_relu.cpp)
target_link_libraries(test_grouped_convnd_fwd_bias_relu PRIVATE utility device_grouped_conv2d_fwd_bias_relu_instance device_grouped_conv3d_fwd_bias_relu_instance)
endif()