diff --git a/library/include/ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp b/library/include/ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp index 1dc3544ecb..05c2182689 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp @@ -23,6 +23,10 @@ void add_device_maxpool_bwd_bf16_instances( void add_device_maxpool_bwd_f32_instances( std::vector>>&); #endif +#ifdef CK_ENABLE_FP8 +void add_device_maxpool_bwd_f8_instances( + std::vector>>&); +#endif #ifdef CK_ENABLE_INT8 void add_device_maxpool_bwd_int8_instances( std::vector>>&); @@ -53,6 +57,11 @@ struct DeviceOperationInstanceFactory< is_same_v) add_device_maxpool_bwd_f32_instances(op_ptrs); #endif +#ifdef CK_ENABLE_FP8 + else if constexpr(is_same_v && is_same_v && + is_same_v) + add_device_maxpool_bwd_f8_instances(op_ptrs); +#endif #ifdef CK_ENABLE_INT8 else if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt index a2315175d8..6925e800b2 100644 --- a/library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt @@ -2,5 +2,6 @@ set(DEVICE_MAXPOOL_BWD_INSTANCES) list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f16_instance.cpp device_max_pool_bwd_bf16_instance.cpp device_max_pool_bwd_f32_instance.cpp + device_max_pool_bwd_f8_instance.cpp device_max_pool_bwd_int8_instance.cpp) add_instance_library(device_max_pool_bwd_instance ${DEVICE_MAXPOOL_BWD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_f8_instance.cpp new file mode 100644 index 0000000000..7c25015bb4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_f8_instance.cpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "max_pool_bwd_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_maxpool_bwd_f8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_maxpool_bwd_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/test/pool/test_max_pool2d_bwd.cpp b/test/pool/test_max_pool2d_bwd.cpp index eae8f2c4d5..65a897dd5b 100644 --- a/test/pool/test_max_pool2d_bwd.cpp +++ b/test/pool/test_max_pool2d_bwd.cpp @@ -55,6 +55,7 @@ using Max_Pool_2D_f32_types = ::testing::Types>; using Max_Pool_2D_int8_types = ::testing::Types>; using Max_Pool_2D_f16_types = ::testing::Types>; using Max_Pool_2D_bf16_types = ::testing::Types>; +using Max_Pool_2D_f8_types = ::testing::Types>; template class MaxPool2D_f32 : public MaxPool2dBWDTest @@ -108,10 +109,24 @@ class MaxPool2D_bf16 : public MaxPool2dBWDTest } }; +template +class MaxPool2D_f8 : public MaxPool2dBWDTest +{ + protected: + void SetUp() override + { + if(!CK_ENABLE_FP8) + { + GTEST_SKIP() << "Skipping MaxPool2D_f8 tests because CK_ENABLE_FP8 is not enabled"; + } + } +}; + TYPED_TEST_SUITE(MaxPool2D_f32, Max_Pool_2D_f32_types); TYPED_TEST_SUITE(MaxPool2D_int8, Max_Pool_2D_int8_types); TYPED_TEST_SUITE(MaxPool2D_f16, Max_Pool_2D_f16_types); TYPED_TEST_SUITE(MaxPool2D_bf16, Max_Pool_2D_bf16_types); +TYPED_TEST_SUITE(MaxPool2D_f8, Max_Pool_2D_f8_types); TYPED_TEST(MaxPool2D_f32, MaxPool2DTest_f32) { this->Run(); } @@ -120,3 +135,5 @@ TYPED_TEST(MaxPool2D_int8, MaxPool2DTest_int8) { this->Run(); } TYPED_TEST(MaxPool2D_f16, MaxPool2DTest_f16) { this->Run(); } TYPED_TEST(MaxPool2D_bf16, MaxPool2DTest_bf16) { this->Run(); } + +TYPED_TEST(MaxPool2D_f8, MaxPool2DTest_f8) { this->Run(); }