diff --git a/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp b/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp index 46745fd02b..51922fde33 100644 --- a/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp +++ b/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp @@ -188,66 +188,42 @@ bool profile_gemm_multi_abd_impl(int do_verification, EDataType, remove_cvref_t>>::type; - auto get_a_matrix = [&]() -> auto { - // in case of pass through we avoid allocating a new - // tensor and copying values - if constexpr(is_same_v) + Tensor a_m_k({M, K}); + for(int m = 0; m < M; ++m) + { + for(int k = 0; k < K; ++k) { - return as_m_k(Number<0>{}); + // result + auto data_refs1 = ck::tie(a_m_k(m, k)); + // inputs + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return as_m_k(Number{})(m, k); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(a_element_op, data_refs); } - else - { - Tensor a_m_k({M, K}); - for(int m = 0; m < M; ++m) - { - for(int k = 0; k < K; ++k) - { - // result - auto data_refs1 = ck::tie(a_m_k(m, k)); - // inputs - auto data_refs2 = - generate_tie([&](auto i) -> auto& { return as_m_k(Number{})(m, k); }, - Number{}); - auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); - unpack(a_element_op, data_refs); - } - } - return a_m_k; - } - }; + } using BComputeType = typename std::conditional<(NumBTensor > 1), EDataType, remove_cvref_t>>::type; - auto get_b_matrix = [&]() -> auto { - // in case of pass through we avoid allocating a new - // tensor and copying values - if constexpr(is_same_v) + Tensor b_k_n({K, N}); + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N; ++n) { - return bs_k_n(Number<0>{}); + // result + auto data_refs1 = ck::tie(b_k_n(k, n)); + // inputs + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return bs_k_n(Number{})(k, n); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(b_element_op, data_refs); } - else - { - Tensor b_k_n({K, N}); - for(int k = 0; k < K; ++k) - { - for(int n = 0; n < N; ++n) - { - // result - auto data_refs1 = ck::tie(b_k_n(k, n)); - // inputs - auto data_refs2 = - generate_tie([&](auto i) -> auto& { return bs_k_n(Number{})(k, n); }, - Number{}); - auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); - unpack(b_element_op, data_refs); - } - } - return b_k_n; - } - }; + } using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm