mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Gemm+Reduce Fusion (#128)
* add gridwise gemm v4r1
* rename
* adding gemm+reduce
* adding gemm+reduce
* adding gemm+reduce
* adding gemm+reduce
* use sfc in shuffling
* remove hardcode
* remove hardcode
* refactor
* fix build
* adding gemm+reduce
* adding gemm+reduce
* adding gemm+reduce
* adding gemm+reduce
* adding gemm+reduce
* format
* clean
* adding gemm+reduce
* adding profiler for gemm+reduce
* adding gemm+reduce profiler
* fix build
* clean up
* gemm+reduce
* fix build
* update DeviceGemm_Xdl_CShuffle; update enum to enum class
* clean up
* add test for gemm+reduce
* clean up
* refactor
* fix build
* fix build
This commit is contained in:
@@ -13,8 +13,10 @@ struct DeviceMem
|
||||
DeviceMem() = delete;
|
||||
DeviceMem(std::size_t mem_size);
|
||||
void* GetDeviceBuffer();
|
||||
std::size_t GetBufferSize();
|
||||
void ToDevice(const void* p);
|
||||
void FromDevice(void* p);
|
||||
void SetZero();
|
||||
~DeviceMem();
|
||||
|
||||
void* mpDeviceBuf;
|
||||
@@ -48,7 +50,6 @@ template <typename... Args, typename F>
|
||||
float launch_and_time_kernel(
|
||||
F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
|
||||
{
|
||||
#if 1
|
||||
KernelTimer timer;
|
||||
|
||||
printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
|
||||
@@ -78,13 +79,6 @@ float launch_and_time_kernel(
|
||||
|
||||
timer.End();
|
||||
|
||||
// std::this_thread::sleep_for (std::chrono::microseconds(10));
|
||||
|
||||
return timer.GetElapsedTime() / nrepeat;
|
||||
#else
|
||||
launch_kernel(kernel, grid_dim, block_dim, lds_byte, args...);
|
||||
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -40,20 +40,6 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
|
||||
return os;
|
||||
}
|
||||
|
||||
typedef enum
|
||||
{
|
||||
Half = 0,
|
||||
Float = 1,
|
||||
} DataType_t;
|
||||
|
||||
template <typename T>
|
||||
struct DataType;
|
||||
|
||||
template <>
|
||||
struct DataType<float> : std::integral_constant<DataType_t, DataType_t::Float>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename F, typename T, std::size_t... Is>
|
||||
auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
|
||||
{
|
||||
@@ -312,49 +298,58 @@ HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> s
|
||||
|
||||
void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout);
|
||||
|
||||
#if 1
|
||||
// FIXME: remove
|
||||
float bf16_to_f32_(ck::bhalf_t src_val);
|
||||
|
||||
// FIXME: remove
|
||||
void bf16_to_f32_(const Tensor<ck::bhalf_t>& src, Tensor<float>& dst);
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
float check_error(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
{
|
||||
float error = 0;
|
||||
float max_diff = -1;
|
||||
float ref_value = 0, result_value = 0;
|
||||
float l1_error = 0;
|
||||
float linf_error = -1;
|
||||
float linf_rel_error = -1;
|
||||
|
||||
if constexpr(std::is_same<ck::bhalf_t, T>::value)
|
||||
float linf_ref_value = 0, linf_result_value = 0;
|
||||
float linf_rel_ref_value = 0, linf_rel_result_value = 0;
|
||||
|
||||
constexpr float eps = 1e-10;
|
||||
|
||||
for(int i = 0; i < ref.mData.size(); ++i)
|
||||
{
|
||||
for(int i = 0; i < ref.mData.size(); ++i)
|
||||
float ref_v = ck::type_convert<float>(ref.mData[i]);
|
||||
float result_v = ck::type_convert<float>(result.mData[i]);
|
||||
|
||||
float diff = std::abs(ref_v - result_v);
|
||||
float rel_diff = diff / std::max(std::abs(ref_v), eps);
|
||||
|
||||
l1_error += diff;
|
||||
|
||||
if(linf_error < diff)
|
||||
{
|
||||
error += std::abs(bf16_to_f32_(ref.mData[i]) - bf16_to_f32_(result.mData[i]));
|
||||
float diff = std::abs(bf16_to_f32_(ref.mData[i]) - bf16_to_f32_(result.mData[i]));
|
||||
if(max_diff < diff)
|
||||
{
|
||||
max_diff = diff;
|
||||
ref_value = bf16_to_f32_(ref.mData[i]);
|
||||
result_value = bf16_to_f32_(result.mData[i]);
|
||||
}
|
||||
linf_error = diff;
|
||||
linf_ref_value = ref_v;
|
||||
linf_result_value = result_v;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i = 0; i < ref.mData.size(); ++i)
|
||||
|
||||
if(linf_rel_error < rel_diff)
|
||||
{
|
||||
error += std::abs(double(ref.mData[i]) - double(result.mData[i]));
|
||||
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
|
||||
if(max_diff < diff)
|
||||
{
|
||||
max_diff = diff;
|
||||
ref_value = ref.mData[i];
|
||||
result_value = result.mData[i];
|
||||
}
|
||||
linf_rel_error = rel_diff;
|
||||
linf_rel_ref_value = ref_v;
|
||||
linf_rel_result_value = result_v;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "error: " << error << std::endl;
|
||||
std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
|
||||
return max_diff;
|
||||
std::cout << "Absolute Error L1 Norm (sum of abs diff): " << l1_error << std::endl;
|
||||
std::cout << "Absolute Error L-inf Norm (max abs diff): " << linf_error << ", ref "
|
||||
<< linf_ref_value << ", result " << linf_result_value << std::endl;
|
||||
std::cout << "Relative Error L-inf Norm (max relative abs diff): " << linf_rel_error << ", ref "
|
||||
<< linf_rel_ref_value << ", result " << linf_rel_result_value << std::endl;
|
||||
|
||||
return linf_error;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
Reference in New Issue
Block a user