change example to turn bf16, increase error threshold

This commit is contained in:
Astha Rai
2025-04-01 21:29:09 +00:00
parent 8f571c0bd8
commit eb8f04cd72
2 changed files with 7 additions and 7 deletions

View File

@@ -5,11 +5,11 @@
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using AccDataType = float;
using CShuffleDataType = float;
using CDataType = ck::half_t;
using CDataType = ck::bhalf_t;
using F16 = ck::half_t;

View File

@@ -18,11 +18,11 @@ inline __host__ __device__ constexpr double get_rtol()
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
return 5e-1;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
@@ -59,11 +59,11 @@ inline __host__ __device__ constexpr double get_atol()
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
return 5e-1;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{