diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc index c7472ba84c..5a76454648 100644 --- a/example/01_gemm/run_gemm_example_v2.inc +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -158,6 +158,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } + b_k_n(0, 0) = 0xaa; + b_k_n(1, 1) = 0xaa; + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); @@ -207,31 +210,42 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) bool pass = true; if(config.do_verification) { - //auto ref_gemm = ReferenceGemmInstance{}; - //auto ref_invoker = ref_gemm.MakeInvoker(); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); - //auto ref_argument = ref_gemm.MakeArgument( - // a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); - //ref_invoker.Run(ref_argument); + ref_invoker.Run(ref_argument); ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1}); c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - //pass &= ck::utils::check_err(c_m_n_device_result, - // c_m_n_host_result, - // "Error: Incorrect results!", - // get_rtol(), - // get_atol()); + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); - //for(int i = 0; i < M; i++) - //{ - // for(int j = 0; j < N; j++) - // { - // std::cout << ck::type_convert(c_m_n_device_result(i, j)) << ","; - // } - // std::cout << std::endl; - //} + std::cout << "c_m_n_device_result: " << std::endl; + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + std::cout << ck::type_convert(c_m_n_device_result(i, j)) << ","; + } + std::cout << std::endl; + } + + std::cout << "c_m_n_host_result: " << std::endl; + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + std::cout << ck::type_convert(c_m_n_host_result(i, j)) << ","; + } + std::cout << std::endl; + } } if(config.time_kernel) diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index a20e3d3556..d8ccb2ea76 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -157,8 +157,8 @@ struct intrin_mfma_f32_16x16x16f16<16, 16> template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - //reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16( - //reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index e1edc4fae0..363db2c85b 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -84,6 +84,17 @@ struct ReferenceGemm : public device::BaseOperator { ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n)); } + else if constexpr(is_same_v) + { + pk_i4_t i4x2 = arg.b_k_n_(k, n); + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2 >> 0) & 0xf; + else + i4 = (i4x2 >> 4) & 0xf; + i4 = i4 - 8; + arg.b_element_op_(v_b, i4); + } else { arg.b_element_op_(v_b, arg.b_k_n_(k, n)); diff --git a/library/include/ck/library/utility/host_tensor.hpp b/library/include/ck/library/utility/host_tensor.hpp index a58acaf116..bc35db92e3 100644 --- a/library/include/ck/library/utility/host_tensor.hpp +++ b/library/include/ck/library/utility/host_tensor.hpp @@ -322,7 +322,12 @@ struct Tensor std::size_t GetElementSize() const { return mDesc.GetElementSize(); } - std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); } + std::size_t GetElementSpaceSize() const { + if constexpr(ck::is_same_v) + return mDesc.GetElementSpaceSize() / 2; + else + return mDesc.GetElementSpaceSize(); + } std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); } @@ -469,29 +474,64 @@ struct Tensor template std::size_t GetOffsetFromMultiIndex(Is... is) const { - return mDesc.GetOffsetFromMultiIndex(is...); + if constexpr(ck::is_same_v) + { + return mDesc.GetOffsetFromMultiIndex(is...) / 2; + } + else + { + return mDesc.GetOffsetFromMultiIndex(is...); + } } template T& operator()(Is... is) { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + if constexpr(ck::is_same_v) + { + return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; + } + else + { + return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + } } template const T& operator()(Is... is) const { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + if constexpr(ck::is_same_v) + { + return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; + } + else + { + return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + } } T& operator()(std::vector idx) { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + if constexpr(ck::is_same_v) + { + return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; + } + else + { + return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + } } const T& operator()(std::vector idx) const { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + if constexpr(ck::is_same_v) + { + return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; + } + else + { + return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + } } typename Data::iterator begin() { return mData.begin(); }