mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Minor fix for recent PR (#260)
* fix example * update IsSupportedArgument * fix * disable fp64 conv example as test
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
#ifndef DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP
|
||||
#define DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
@@ -8,6 +7,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "device.hpp"
|
||||
#include "device_prop.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_conv_fwd.hpp"
|
||||
#include "convolution_forward_specialization.hpp"
|
||||
@@ -858,6 +858,27 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx908")
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
|
||||
is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if(ck::get_device_name() == "gfx90a")
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
|
||||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Input tensors can't be bigger than 2GB each.
|
||||
constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31);
|
||||
|
||||
@@ -1021,4 +1042,3 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "device.hpp"
|
||||
#include "device_prop.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_gemm.hpp"
|
||||
#include "common_header.hpp"
|
||||
@@ -13,7 +14,6 @@
|
||||
#include "gemm_specialization.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "gridwise_gemm_dl_v1r3.hpp"
|
||||
#include "device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_prop.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_gemm.hpp"
|
||||
#include "common_header.hpp"
|
||||
@@ -11,7 +12,6 @@
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
#include "device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -408,7 +408,23 @@ struct DeviceGemmXdl
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
if(ck::get_device_name() == "gfx908")
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
|
||||
is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if(ck::get_device_name() == "gfx90a")
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
|
||||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user