mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Add bfp16/int8 support into XDL GEMM operator (#50)
* init StaticBufferV2 * clean * adopt old output stage for staticBufferV2 * clean * remove hack * clean * clean * add parameters * clean code * move c_buffer alloc into blockwise gemm * add adaptors for m/n_thread_data_on_grid * tweak gemm * adjust blockwise_gemm_xdlops * tweak * update conv * update script * adding bwd 1x1 * update script * adding 1x1 bwd * debugging bwd 1x1 failure * update script * update script * test * test v100 * add bf16_1k * clang-format * clean * add bfp16 for gfx908 * add verification * clean up * clean code * restore bfl16 * clean * add bfp16 support into gemm_driver * apply new generator to other drivers * add int8 support * cleanb * clean * clean * clean Co-authored-by: Chao Liu <chao.liu2@amd.com> Co-authored-by: Chao Liu <lc.roy86@gmail.com> Co-authored-by: root <root@hayabusa6111.amd.com>
This commit is contained in:
@@ -321,4 +321,41 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
|
||||
}
|
||||
|
||||
float bf16_to_f32(ushort src_val)
|
||||
{
|
||||
typedef union
|
||||
{
|
||||
ushort x, y;
|
||||
float f32;
|
||||
} bf16_f32_t;
|
||||
|
||||
bf16_f32_t v;
|
||||
v.x = 0;
|
||||
v.y = src_val;
|
||||
return v.f32;
|
||||
}
|
||||
|
||||
template <>
|
||||
void check_error<ushort>(const Tensor<ushort>& ref, const Tensor<ushort>& result)
|
||||
{
|
||||
float error = 0;
|
||||
float max_diff = -1;
|
||||
float ref_value = 0, result_value = 0;
|
||||
for(int i = 0; i < ref.mData.size(); ++i)
|
||||
{
|
||||
error += std::abs(bf16_to_f32(ref.mData[i]) - bf16_to_f32(result.mData[i]));
|
||||
float diff = std::abs(bf16_to_f32(ref.mData[i]) - bf16_to_f32(result.mData[i]));
|
||||
if(max_diff < diff)
|
||||
{
|
||||
max_diff = diff;
|
||||
ref_value = bf16_to_f32(ref.mData[i]);
|
||||
result_value = bf16_to_f32(result.mData[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "error: " << error << std::endl;
|
||||
std::cout << "max_diff: " << max_diff << ", ref: " << ref_value << ", res: " << result_value
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user