mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
fixed reference and host_tensor
This commit is contained in:
@@ -158,6 +158,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
b_k_n(0, 0) = 0xaa;
|
||||
b_k_n(1, 1) = 0xaa;
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
@@ -207,31 +210,42 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
//auto ref_gemm = ReferenceGemmInstance{};
|
||||
//auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
//auto ref_argument = ref_gemm.MakeArgument(
|
||||
// a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
//ref_invoker.Run(ref_argument);
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
//pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
// c_m_n_host_result,
|
||||
// "Error: Incorrect results!",
|
||||
// get_rtol<CDataType>(),
|
||||
// get_atol<CDataType>());
|
||||
pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
|
||||
//for(int i = 0; i < M; i++)
|
||||
//{
|
||||
// for(int j = 0; j < N; j++)
|
||||
// {
|
||||
// std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
|
||||
// }
|
||||
// std::cout << std::endl;
|
||||
//}
|
||||
std::cout << "c_m_n_device_result: " << std::endl;
|
||||
for(int i = 0; i < M; i++)
|
||||
{
|
||||
for(int j = 0; j < N; j++)
|
||||
{
|
||||
std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "c_m_n_host_result: " << std::endl;
|
||||
for(int i = 0; i < M; i++)
|
||||
{
|
||||
for(int j = 0; j < N; j++)
|
||||
{
|
||||
std::cout << ck::type_convert<float>(c_m_n_host_result(i, j)) << ",";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
|
||||
@@ -157,8 +157,8 @@ struct intrin_mfma_f32_16x16x16f16<16, 16>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
//reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
|
||||
//reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -84,6 +84,17 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
{
|
||||
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
|
||||
}
|
||||
else if constexpr(is_same_v<BDataType, pk_i4_t>)
|
||||
{
|
||||
pk_i4_t i4x2 = arg.b_k_n_(k, n);
|
||||
int8_t i4 = 0;
|
||||
if(k % 2 == 1)
|
||||
i4 = (i4x2 >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
i4 = i4 - 8;
|
||||
arg.b_element_op_(v_b, i4);
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
|
||||
|
||||
@@ -322,7 +322,12 @@ struct Tensor
|
||||
|
||||
/// Forwards to the descriptor's GetElementSize() and returns the result as std::size_t.
std::size_t GetElementSize() const
{
    const std::size_t element_count = mDesc.GetElementSize();
    return element_count;
}
|
||||
|
||||
std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }
|
||||
std::size_t GetElementSpaceSize() const {
|
||||
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
|
||||
return mDesc.GetElementSpaceSize() / 2;
|
||||
else
|
||||
return mDesc.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
/// Byte size of the backing storage: GetElementSpaceSize() elements of sizeof(T) bytes each.
std::size_t GetElementSpaceSizeInBytes() const
{
    const std::size_t storage_elements = GetElementSpaceSize();
    return sizeof(T) * storage_elements;
}
|
||||
|
||||
@@ -469,29 +474,64 @@ struct Tensor
|
||||
template <typename... Is>
|
||||
std::size_t GetOffsetFromMultiIndex(Is... is) const
|
||||
{
|
||||
return mDesc.GetOffsetFromMultiIndex(is...);
|
||||
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
|
||||
{
|
||||
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return mDesc.GetOffsetFromMultiIndex(is...);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Is>
|
||||
T& operator()(Is... is)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
|
||||
}
|
||||
else
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Is>
|
||||
const T& operator()(Is... is) const
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
|
||||
}
|
||||
else
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
}
|
||||
|
||||
/// Mutable element access by an index vector.
///
/// The descriptor turns `idx` into a logical offset. For ck::pk_i4_t the stored element at
/// `offset / 2` is returned (two logical offsets share one stored T, so the caller receives
/// the whole packed element); for all other T the element at `offset`.
///
/// @param idx multi-index, passed on to the descriptor
/// @return reference to the stored element that holds the addressed value
T& operator()(std::vector<std::size_t> idx)
{
    // No unconditional early return here: it would bypass the pk_i4_t halving below.
    const std::size_t logical_offset = mDesc.GetOffsetFromMultiIndex(idx);

    if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
    {
        return mData[logical_offset / 2];
    }
    else
    {
        return mData[logical_offset];
    }
}
|
||||
|
||||
/// Read-only element access by an index vector.
///
/// The descriptor turns `idx` into a logical offset. For ck::pk_i4_t the stored element at
/// `offset / 2` is returned (two logical offsets share one stored T, so the caller receives
/// the whole packed element); for all other T the element at `offset`.
///
/// @param idx multi-index, passed on to the descriptor
/// @return const reference to the stored element that holds the addressed value
const T& operator()(std::vector<std::size_t> idx) const
{
    // No unconditional early return here: it would bypass the pk_i4_t halving below.
    const std::size_t logical_offset = mDesc.GetOffsetFromMultiIndex(idx);

    if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
    {
        return mData[logical_offset / 2];
    }
    else
    {
        return mData[logical_offset];
    }
}
|
||||
|
||||
/// Mutable iterator to the start of the underlying mData container.
typename Data::iterator begin()
{
    auto first = mData.begin();
    return first;
}
|
||||
|
||||
Reference in New Issue
Block a user