mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
Support for dtypes (fp8, bf8, bf16 and fp16) for the ck_tile/03_gemm example. (#1845)
* Support bf16/fp8/bf8 datatypes for ck_tile/gemm * Remove commented-out code. * Addressing code review comments and enabling universal_gemm for all the supported data types. * Merge conflict resolution. * Solve the memory pipeline compilation error. Merge with the new change of CShuffle * Finish the feature, pass the tests * Fix the pipeline and add the benchmark script for other data types --------- Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -22,13 +22,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
|
||||
double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
{
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using F16 = ck_tile::half_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
|
||||
|
||||
double compute_error = 0;
|
||||
@@ -41,7 +42,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
|
||||
}
|
||||
|
||||
static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled OutDataType for setting up the relative threshold!");
|
||||
|
||||
double output_error = 0;
|
||||
@@ -55,7 +56,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
}
|
||||
double midway_error = std::max(compute_error, output_error);
|
||||
|
||||
static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled AccDataType for setting up the relative threshold!");
|
||||
|
||||
double acc_error = 0;
|
||||
@@ -74,13 +75,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
|
||||
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
|
||||
{
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using F16 = ck_tile::half_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
|
||||
|
||||
auto expo = std::log2(std::abs(max_possible_num));
|
||||
@@ -94,7 +96,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
|
||||
}
|
||||
|
||||
static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
|
||||
|
||||
double output_error = 0;
|
||||
@@ -108,7 +110,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
}
|
||||
double midway_error = std::max(compute_error, output_error);
|
||||
|
||||
static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
|
||||
|
||||
double acc_error = 0;
|
||||
@@ -501,7 +503,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user