mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Rangify FillUniformDistributionIntegerValue<> (#443)
Allow passing forward range to its call operator
[ROCm/composable_kernel commit: 6f0564f013]
This commit is contained in:
@@ -32,14 +32,12 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k.begin(),
|
||||
a_m_k.end());
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n.begin(),
|
||||
b_k_n.end());
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
break;
|
||||
default:
|
||||
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k.begin(), a_m_k.end());
|
||||
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n.begin(), b_k_n.end());
|
||||
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
}
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
@@ -77,15 +77,12 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-8, 7}(conv_input.begin(),
|
||||
conv_input.end());
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-8, 7}(conv_weight.begin(),
|
||||
conv_weight.end());
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-8, 7}(conv_input);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-8, 7}(conv_weight);
|
||||
break;
|
||||
default:
|
||||
ck::utils::FillUniformDistribution<ADataType>{-5, 5}(conv_input.begin(), conv_input.end());
|
||||
ck::utils::FillUniformDistribution<BDataType>{-5, 5}(conv_weight.begin(),
|
||||
conv_weight.end());
|
||||
ck::utils::FillUniformDistribution<ADataType>{-5, 5}(conv_input);
|
||||
ck::utils::FillUniformDistribution<BDataType>{-5, 5}(conv_weight);
|
||||
}
|
||||
|
||||
DeviceMem conv_input_device_buf(sizeof(ADataType) * conv_input.mDesc.GetElementSpaceSize());
|
||||
|
||||
@@ -160,14 +160,12 @@ bool run_gemm_reduce_add_addsquare_xdl(ck::index_t M,
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k.begin(),
|
||||
a_m_k.end());
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n.begin(),
|
||||
b_k_n.end());
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
break;
|
||||
default:
|
||||
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k.begin(), a_m_k.end());
|
||||
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n.begin(), b_k_n.end());
|
||||
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
@@ -134,14 +134,12 @@ auto run_gemm_reduce_max_xdl(ck::index_t M,
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k.begin(),
|
||||
a_m_k.end());
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n.begin(),
|
||||
b_k_n.end());
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
break;
|
||||
default:
|
||||
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k.begin(), a_m_k.end());
|
||||
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n.begin(), b_k_n.end());
|
||||
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -339,14 +337,12 @@ bool run_gemm_reduce_mean_meansquare_xdl(ck::index_t M,
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k.begin(),
|
||||
a_m_k.end());
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n.begin(),
|
||||
b_k_n.end());
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
break;
|
||||
default:
|
||||
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k.begin(), a_m_k.end());
|
||||
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n.begin(), b_k_n.end());
|
||||
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
@@ -100,9 +100,9 @@ int main(int argc, char* argv[])
|
||||
Tensor<GammaDataType> gamma({G, C});
|
||||
Tensor<BetaDataType> beta({G, C});
|
||||
|
||||
ck::utils::FillUniformDistribution<XDataType>{0.f, 1.f}(x.begin(), x.end());
|
||||
ck::utils::FillUniformDistribution<GammaDataType>{0.f, 1.f}(gamma.begin(), gamma.end());
|
||||
ck::utils::FillUniformDistribution<BetaDataType>{0.f, 1.f}(beta.begin(), beta.end());
|
||||
ck::utils::FillUniformDistribution<XDataType>{0.f, 1.f}(x);
|
||||
ck::utils::FillUniformDistribution<GammaDataType>{0.f, 1.f}(gamma);
|
||||
ck::utils::FillUniformDistribution<BetaDataType>{0.f, 1.f}(beta);
|
||||
|
||||
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
|
||||
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
|
||||
|
||||
@@ -30,9 +30,10 @@ struct FillUniformDistribution
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
auto operator()(ForwardRange&& range) -> std::void_t<decltype(
|
||||
std::declval<FillUniformDistribution>()(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
auto operator()(ForwardRange&& range) const
|
||||
-> std::void_t<decltype(std::declval<const FillUniformDistribution&>()(
|
||||
std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
{
|
||||
(*this)(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range)));
|
||||
@@ -72,6 +73,16 @@ struct FillUniformDistributionIntegerValue
|
||||
std::generate(
|
||||
first, last, [&dis, &gen]() { return ck::type_convert<T>(std::round(dis(gen))); });
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
auto operator()(ForwardRange&& range) const
|
||||
-> std::void_t<decltype(std::declval<const FillUniformDistributionIntegerValue&>()(
|
||||
std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
{
|
||||
(*this)(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range)));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
||||
Reference in New Issue
Block a user