mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Fix multi-abd tests bug (#3099)
This commit is contained in:
@@ -188,66 +188,42 @@ bool profile_gemm_multi_abd_impl(int do_verification,
|
||||
EDataType,
|
||||
remove_cvref_t<tuple_element_t<0, AsDataType>>>::type;
|
||||
|
||||
auto get_a_matrix = [&]() -> auto {
|
||||
// in case of pass through we avoid allocating a new
|
||||
// tensor and copying values
|
||||
if constexpr(is_same_v<AElementOp, PassThrough>)
|
||||
Tensor<AComputeType> a_m_k({M, K});
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
return as_m_k(Number<0>{});
|
||||
// result
|
||||
auto data_refs1 = ck::tie(a_m_k(m, k));
|
||||
// inputs
|
||||
auto data_refs2 =
|
||||
generate_tie([&](auto i) -> auto& { return as_m_k(Number<i>{})(m, k); },
|
||||
Number<NumATensor>{});
|
||||
auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
|
||||
unpack(a_element_op, data_refs);
|
||||
}
|
||||
else
|
||||
{
|
||||
Tensor<AComputeType> a_m_k({M, K});
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
// result
|
||||
auto data_refs1 = ck::tie(a_m_k(m, k));
|
||||
// inputs
|
||||
auto data_refs2 =
|
||||
generate_tie([&](auto i) -> auto& { return as_m_k(Number<i>{})(m, k); },
|
||||
Number<NumATensor>{});
|
||||
auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
|
||||
unpack(a_element_op, data_refs);
|
||||
}
|
||||
}
|
||||
return a_m_k;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
using BComputeType =
|
||||
typename std::conditional<(NumBTensor > 1),
|
||||
EDataType,
|
||||
remove_cvref_t<tuple_element_t<0, BsDataType>>>::type;
|
||||
|
||||
auto get_b_matrix = [&]() -> auto {
|
||||
// in case of pass through we avoid allocating a new
|
||||
// tensor and copying values
|
||||
if constexpr(is_same_v<BElementOp, PassThrough>)
|
||||
Tensor<BComputeType> b_k_n({K, N});
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
return bs_k_n(Number<0>{});
|
||||
// result
|
||||
auto data_refs1 = ck::tie(b_k_n(k, n));
|
||||
// inputs
|
||||
auto data_refs2 =
|
||||
generate_tie([&](auto i) -> auto& { return bs_k_n(Number<i>{})(k, n); },
|
||||
Number<NumBTensor>{});
|
||||
auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
|
||||
unpack(b_element_op, data_refs);
|
||||
}
|
||||
else
|
||||
{
|
||||
Tensor<BComputeType> b_k_n({K, N});
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
// result
|
||||
auto data_refs1 = ck::tie(b_k_n(k, n));
|
||||
// inputs
|
||||
auto data_refs2 =
|
||||
generate_tie([&](auto i) -> auto& { return bs_k_n(Number<i>{})(k, n); },
|
||||
Number<NumBTensor>{});
|
||||
auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
|
||||
unpack(b_element_op, data_refs);
|
||||
}
|
||||
}
|
||||
return b_k_n;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<AComputeType,
|
||||
BComputeType,
|
||||
@@ -259,8 +235,8 @@ bool profile_gemm_multi_abd_impl(int do_verification,
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
get_a_matrix(), get_b_matrix(), c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
auto ref_argument =
|
||||
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user