mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
Add FP64 XDL GEMM built-in function (#199)
* add intrin_mfma_f64_16x16x4f64
* add example
* gemm reference add double data type
* chang init data
* fix M N PerXdlops
* fix ifdef
* add comparsion config
* add conv fwd example
* format log out
* change rc matrix egister layout
* reorganize example
* reorganize example 2
* format,because merge develop
* fix call impl adding acc data type
* lost ;
* add compiler warning
* change example tunning parameters
* add test for fp64
* add instance
* add test/gemm/gemm_fp64.cpp
* fix get name issue
* remove some tunning parameter
* fix conflict
* format
* use integer value for GEMM test
* add acc data type
* remove typeid because fp16
* fix streamconfig etc bug from merging develop
* format
* remove test_gemm_xdl_fp64
* add AccDataType
* AccDataType problem
Co-authored-by: qinletao <letaoqin@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: 3e6c2610ae]
This commit is contained in:
@@ -11,6 +11,7 @@ namespace host {
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
@@ -53,20 +54,20 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
|
||||
|
||||
float v_acc = 0;
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
float v_a;
|
||||
float v_b;
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k)));
|
||||
arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n)));
|
||||
arg.a_element_op_(v_a, static_cast<const AccDataType>(arg.a_m_k_(m, k)));
|
||||
arg.b_element_op_(v_b, static_cast<const AccDataType>(arg.b_k_n_(k, n)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
float v_c;
|
||||
AccDataType v_c;
|
||||
|
||||
arg.c_element_op_(v_c, v_acc);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user