Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-13 09:45:56 +00:00)
Revert "Add support for mixed precision in contraction scale and bilinear" (#967)
* Revert "Add support for mixed precision in contraction scale and bilinear (#936)"
This reverts commit f07485060e.
* Revert commits #957 and #960

@@ -31,14 +31,10 @@ namespace profiler {
 using Bilinear = ck::tensor_operation::element_wise::Bilinear;
 using Scale = ck::tensor_operation::element_wise::Scale;
 
-using F32 = float;
-using F64 = double;
-
 template <typename ALayout,
           typename BLayout,
           typename CDELayout,
           typename DataType,
-          typename ComputeDataType,
           typename DTupleDataType,
           typename CDElementOp>
 int profile_contraction_impl(ck::index_t do_verification,
@@ -49,10 +45,10 @@ int profile_contraction_impl(ck::index_t do_verification,
                              const std::vector<ck::index_t>& M,
                              const std::vector<ck::index_t>& N,
                              const std::vector<ck::index_t>& K,
-                             const std::vector<ck::index_t>& StridesA, // [M0, M1, K0, K1]
-                             const std::vector<ck::index_t>& StridesB, // [N0, N1, K0, K1]
-                             const std::vector<ck::index_t>& StridesE, // [M0, M1, N0, N1]
-                             const std::vector<ck::index_t>& StridesD) // [M0, M1, N0, N1]
+                             const std::vector<ck::index_t>& StridesA,
+                             const std::vector<ck::index_t>& StridesB,
+                             const std::vector<ck::index_t>& StridesE,
+                             const std::vector<ck::index_t>& StridesD)
 {
     bool pass = true;
 
@@ -67,13 +63,13 @@ int profile_contraction_impl(ck::index_t do_verification,
     };
 
     Tensor<DataType> a_m_k(f_host_tensor_descriptor(M, K, StridesA));
-    Tensor<DataType> b_n_k(f_host_tensor_descriptor(N, K, StridesB));
+    Tensor<DataType> b_k_n(f_host_tensor_descriptor(K, N, StridesB));
     Tensor<DataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE));
     Tensor<DataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesE));
     Tensor<DataType> d_m_n(f_host_tensor_descriptor(M, N, StridesD));
 
     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
-    std::cout << "b_n_k: " << b_n_k.mDesc << std::endl;
+    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
     std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
     std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
 
@@ -82,12 +78,12 @@ int profile_contraction_impl(ck::index_t do_verification,
     case 0: break;
     case 1:
         a_m_k.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
-        b_n_k.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
        d_m_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
        break;
    default:
        a_m_k.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
-        b_n_k.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
        d_m_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
    }
 
@@ -95,12 +91,12 @@ int profile_contraction_impl(ck::index_t do_verification,
     using BElementOp = ck::tensor_operation::element_wise::PassThrough;
 
     DeviceMem a_device_buf(sizeof(DataType) * a_m_k.mDesc.GetElementSpaceSize());
-    DeviceMem b_device_buf(sizeof(DataType) * b_n_k.mDesc.GetElementSpaceSize());
+    DeviceMem b_device_buf(sizeof(DataType) * b_k_n.mDesc.GetElementSpaceSize());
     DeviceMem e_device_buf(sizeof(DataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
     DeviceMem d_device_buf(sizeof(DataType) * d_m_n.mDesc.GetElementSpaceSize());
 
     a_device_buf.ToDevice(a_m_k.mData.data());
-    b_device_buf.ToDevice(b_n_k.mData.data());
+    b_device_buf.ToDevice(b_k_n.mData.data());
     e_device_buf.SetZero();
     d_device_buf.ToDevice(d_m_n.mData.data());
 
@@ -122,8 +118,7 @@ int profile_contraction_impl(ck::index_t do_verification,
             DataType,
             AElementOp,
             BElementOp,
-            CDElementOp,
-            ComputeDataType>;
+            CDElementOp>;
 
     // get device op instances
     const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
@@ -131,9 +126,6 @@ int profile_contraction_impl(ck::index_t do_verification,
 
     std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
 
-    using AccDataType =
-        typename std::conditional<std::is_same<ComputeDataType, F64>::value, F64, F32>::type;
-
     // Run reference op
     if(do_verification)
     {
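Note on the lines removed in the hunk above: the mixed-precision path derived its accumulation type from `ComputeDataType` with `std::conditional`, accumulating in `F64` only when the compute type was `F64` and falling back to `F32` otherwise. Below is a minimal standalone sketch of that selection pattern; the alias template `AccDataTypeFor` and the `main` stub are hypothetical names introduced only for this illustration and are not part of CK.

// Sketch only: the std::conditional selection used by the removed lines.
#include <type_traits>

using F32 = float;
using F64 = double;

// Hypothetical alias: F64 compute accumulates in F64, everything else in F32.
template <typename ComputeDataType>
using AccDataTypeFor =
    typename std::conditional<std::is_same<ComputeDataType, F64>::value, F64, F32>::type;

static_assert(std::is_same<AccDataTypeFor<F64>, F64>::value, "F64 compute -> F64 accumulation");
static_assert(std::is_same<AccDataTypeFor<F32>, F32>::value, "F32 compute -> F32 accumulation");
static_assert(std::is_same<AccDataTypeFor<int>, F32>::value, "any other type -> F32 accumulation");

int main() { return 0; }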
@@ -144,8 +136,7 @@ int profile_contraction_impl(ck::index_t do_verification,
             DataType,
             DataType,
             DataType,
-            AccDataType,
-            ComputeDataType,
+            DataType,
             AElementOp,
             BElementOp>;
 
@@ -155,7 +146,7 @@ int profile_contraction_impl(ck::index_t do_verification,
         Tensor<DataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE));
 
         auto ref_argument =
-            ref_op.MakeArgument(a_m_k, b_n_k, c_m_n_host_result, a_element_op, b_element_op);
+            ref_op.MakeArgument(a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op);
 
         ref_invoker.Run(ref_argument);
 
@@ -281,29 +272,8 @@ int profile_contraction_impl(ck::index_t do_verification,
         {
             e_device_buf.FromDevice(e_m_n_device_result.mData.data());
 
-            // Both the kernel and the reference use `AccDataType`, so an absolute error of both
-            // of them is bounded by `nelems_k * std::numeric_limits<AccDataType>::epsilon()`.
-            // Comparing one to another can result in an absolute error as high as twice that
-            // value.
-            double threshold = 2 * nelems_k * std::numeric_limits<AccDataType>::epsilon();
-            // Handle the possible casting error of either AccDataType -> DataType or
-            // DataType -> ComputeDataType.
-            // TODO: Add a generic solution for calculating thresholds in CK.
-            if constexpr(ck::is_same_v<DataType, ck::bhalf_t> ||
-                         ck::is_same_v<ComputeDataType, ck::bhalf_t>)
-            {
-                const double epsilon = std::pow(2, -7);
-                // Maximum relative casting error when rounding to zero.
-                threshold += epsilon * 2;
-            }
-            else if constexpr(ck::is_same_v<DataType, ck::half_t> ||
-                              ck::is_same_v<ComputeDataType, ck::half_t>)
-            {
-                const double epsilon = std::pow(2, -10);
-                // Maximum relative casting error when rounding to zero.
-                threshold += epsilon * 2;
-            }
-
+            float threshold =
+                static_cast<DataType>(nelems_k) * std::numeric_limits<DataType>::epsilon();
             pass = pass & ck::utils::check_err(e_m_n_device_result,
                                                e_m_n_host_result,
                                                "Error: incorrect results!",
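Note on the block removed in the hunk above: the reverted comment argues that the kernel and the reference each accumulate an absolute error of up to nelems_k * epsilon(AccDataType), so comparing the two can double that bound, and that a further term of twice the maximum relative casting error (2^-10 for fp16, 2^-7 for bf16, rounding toward zero) is needed when a narrowing cast is involved. The following standalone sketch restates that arithmetic; the helper `verification_threshold` is a hypothetical name introduced only for illustration and is not part of CK.

// Sketch only: the threshold reasoning from the removed comment.
#include <cmath>
#include <cstdio>
#include <limits>

// nelems_k:     length of the contracted (K) dimension.
// acc_epsilon:  machine epsilon of the accumulation type.
// cast_epsilon: maximum relative casting error of the narrower type involved (0 if none).
double verification_threshold(double nelems_k, double acc_epsilon, double cast_epsilon)
{
    // Kernel and reference each carry up to nelems_k * acc_epsilon of absolute error,
    // so their difference can be as large as twice that value...
    double threshold = 2.0 * nelems_k * acc_epsilon;
    // ...plus up to twice the maximum relative casting error when rounding toward zero.
    threshold += 2.0 * cast_epsilon;
    return threshold;
}

int main()
{
    // Example: F32 accumulation over K = 4096, result cast to fp16 (epsilon = 2^-10).
    std::printf("threshold = %g\n",
                verification_threshold(4096.0,
                                       std::numeric_limits<float>::epsilon(),
                                       std::pow(2.0, -10)));
    return 0;
}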
@@ -313,7 +283,7 @@ int profile_contraction_impl(ck::index_t do_verification,
         if(do_log)
         {
             LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
-            LogRangeAsType<float>(std::cout << "b: ", b_n_k.mData, ",") << std::endl;
+            LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
             LogRangeAsType<float>(std::cout << "c_host : ", e_m_n_host_result.mData, ",")
                 << std::endl;
             LogRangeAsType<float>(std::cout << "c_device: ", e_m_n_device_result.mData, ",")

@@ -23,18 +23,8 @@ enum struct ContractionMatrixLayout
 
 enum struct ContractionDataType
 {
-    F32_F32_F32_F32,     // 0
-    F64_F64_F64_F64,     // 1
-    F16_F16_F16_F16,     // 2
-    BF16_BF16_BF16_BF16, // 3
-};
-
-enum struct ContractionComputeDataType
-{
-    F32 = 0,
-    F64,
-    F16,
-    BF16,
+    F32_F32_F32_F32, // 0
+    F64_F64_F64_F64, // 1
 };
 
 inline void collect_index_params(char* argv[],