mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Misc fixes (#994)
* reinterpret_cast to const char* in dumpBufferToFile to be compatible with both const and non-const input pointers * Add seed input to GeneratorTensor_4 for normal_distribution generator * Add GetTypeString() for DeviceElementwiseImpl * Add HIP_CHECK_ERROR macro
This commit is contained in:
@@ -3,8 +3,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
// To be removed, which really does not tell the location of failed HIP functional call
|
||||
inline void hip_check_error(hipError_t x)
|
||||
{
|
||||
if(x != hipSuccess)
|
||||
@@ -15,3 +17,16 @@ inline void hip_check_error(hipError_t x)
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
#define HIP_CHECK_ERROR(retval_or_funcall) \
|
||||
do \
|
||||
{ \
|
||||
hipError_t _tmpVal = retval_or_funcall; \
|
||||
if(_tmpVal != hipSuccess) \
|
||||
{ \
|
||||
std::ostringstream ostr; \
|
||||
ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \
|
||||
<< hipGetErrorString(_tmpVal); \
|
||||
throw std::runtime_error(ostr.str()); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
@@ -296,6 +296,28 @@ struct DeviceElementwiseImpl
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceElementwiseImpl<" ;
|
||||
str << "NumDim_" << NumDim << ",";
|
||||
str << "MPerThread_" << MPerThread << ",";
|
||||
|
||||
str << "InScalarPerVector";
|
||||
static_for<0, InScalarPerVectorSeq::Size(), 1>{}([&](auto i) { str << "_" << InScalarPerVectorSeq::At(i).value; });
|
||||
str << ",";
|
||||
str << "OutScalarPerVector";
|
||||
static_for<0, OutScalarPerVectorSeq::Size(), 1>{}([&](auto i) { str << "_" << OutScalarPerVectorSeq::At(i).value; });
|
||||
|
||||
str << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
}; // namespace device
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -22,7 +22,7 @@ static inline void dumpBufferToFile(const char* fileName, T* data, size_t dataNu
|
||||
std::ofstream outFile(fileName, std::ios::binary);
|
||||
if(outFile)
|
||||
{
|
||||
outFile.write(reinterpret_cast<char*>(data), dataNumItems * sizeof(T));
|
||||
outFile.write(reinterpret_cast<const char*>(data), dataNumItems * sizeof(T));
|
||||
outFile.close();
|
||||
std::cout << "Write output to file " << fileName << std::endl;
|
||||
}
|
||||
|
||||
@@ -200,10 +200,11 @@ struct GeneratorTensor_3<ck::bf8_t>
|
||||
template <typename T>
|
||||
struct GeneratorTensor_4
|
||||
{
|
||||
std::default_random_engine generator;
|
||||
std::mt19937 generator;
|
||||
std::normal_distribution<float> distribution;
|
||||
|
||||
GeneratorTensor_4(float mean, float stddev) : generator(1), distribution(mean, stddev){};
|
||||
GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1)
|
||||
: generator(seed), distribution(mean, stddev){};
|
||||
|
||||
template <typename... Is>
|
||||
T operator()(Is...)
|
||||
|
||||
Reference in New Issue
Block a user