Fix multi-abd tests bug (#3099)

This commit is contained in:
Enrico Degregori
2025-10-27 16:09:02 +01:00
committed by GitHub
parent a1ce64374f
commit 06973b1cf4

View File

@@ -188,66 +188,42 @@ bool profile_gemm_multi_abd_impl(int do_verification,
EDataType,
remove_cvref_t<tuple_element_t<0, AsDataType>>>::type;
auto get_a_matrix = [&]() -> auto {
// in case of pass through we avoid allocating a new
// tensor and copying values
if constexpr(is_same_v<AElementOp, PassThrough>)
Tensor<AComputeType> a_m_k({M, K});
for(int m = 0; m < M; ++m)
{
for(int k = 0; k < K; ++k)
{
return as_m_k(Number<0>{});
// result
auto data_refs1 = ck::tie(a_m_k(m, k));
// inputs
auto data_refs2 =
generate_tie([&](auto i) -> auto& { return as_m_k(Number<i>{})(m, k); },
Number<NumATensor>{});
auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
unpack(a_element_op, data_refs);
}
else
{
Tensor<AComputeType> a_m_k({M, K});
for(int m = 0; m < M; ++m)
{
for(int k = 0; k < K; ++k)
{
// result
auto data_refs1 = ck::tie(a_m_k(m, k));
// inputs
auto data_refs2 =
generate_tie([&](auto i) -> auto& { return as_m_k(Number<i>{})(m, k); },
Number<NumATensor>{});
auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
unpack(a_element_op, data_refs);
}
}
return a_m_k;
}
};
}
using BComputeType =
typename std::conditional<(NumBTensor > 1),
EDataType,
remove_cvref_t<tuple_element_t<0, BsDataType>>>::type;
auto get_b_matrix = [&]() -> auto {
// in case of pass through we avoid allocating a new
// tensor and copying values
if constexpr(is_same_v<BElementOp, PassThrough>)
Tensor<BComputeType> b_k_n({K, N});
for(int k = 0; k < K; ++k)
{
for(int n = 0; n < N; ++n)
{
return bs_k_n(Number<0>{});
// result
auto data_refs1 = ck::tie(b_k_n(k, n));
// inputs
auto data_refs2 =
generate_tie([&](auto i) -> auto& { return bs_k_n(Number<i>{})(k, n); },
Number<NumBTensor>{});
auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
unpack(b_element_op, data_refs);
}
else
{
Tensor<BComputeType> b_k_n({K, N});
for(int k = 0; k < K; ++k)
{
for(int n = 0; n < N; ++n)
{
// result
auto data_refs1 = ck::tie(b_k_n(k, n));
// inputs
auto data_refs2 =
generate_tie([&](auto i) -> auto& { return bs_k_n(Number<i>{})(k, n); },
Number<NumBTensor>{});
auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
unpack(b_element_op, data_refs);
}
}
return b_k_n;
}
};
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<AComputeType,
BComputeType,
@@ -259,8 +235,8 @@ bool profile_gemm_multi_abd_impl(int do_verification,
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
get_a_matrix(), get_b_matrix(), c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);