mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
Add dynamic elementwise op (#1426)
* Add dynamic elementwise op
Co-authored-by: ThruptiRajLakshmanaGowda <thruptiraj.lakshmanagowda@amd.com>
* CI issues fix
* Custom parameter value for dynamic functions - Comments addressed
---------
Co-authored-by: ThruptiRajLakshmanaGowda <thruptiraj.lakshmanagowda@amd.com>
Co-authored-by: ThruptiRajLakshmanaGowda <tlakshma@amd.com>
[ROCm/composable_kernel commit: 31bf253aeb]
This commit is contained in:
@@ -85,9 +85,9 @@ __global__ void
|
||||
BsPointer p_bs_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
@@ -121,6 +121,19 @@ __global__ void
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });
|
||||
|
||||
if constexpr(is_same_v<AElementwiseOperation, element_wise::DynamicUnaryOp>)
|
||||
{
|
||||
a_element_op.InitUnaryOpPtrOnDevice();
|
||||
}
|
||||
if constexpr(is_same_v<BElementwiseOperation, element_wise::DynamicUnaryOp>)
|
||||
{
|
||||
b_element_op.InitUnaryOpPtrOnDevice();
|
||||
}
|
||||
if constexpr(is_same_v<CDEElementwiseOperation, element_wise::DynamicUnaryOp>)
|
||||
{
|
||||
cde_element_op.InitUnaryOpPtrOnDevice();
|
||||
}
|
||||
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
{
|
||||
AsPointer p_as_grid_grp;
|
||||
|
||||
@@ -405,7 +405,7 @@ struct ScaleAddScaleAddRelu
|
||||
const float& d1) const
|
||||
{
|
||||
const float x = c * alpha1_ + alpha2_ * d0 + d1;
|
||||
Relu{}.template operator()<float>(e, x);
|
||||
e = x > 0 ? x : 0;
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -416,7 +416,7 @@ struct ScaleAddScaleAddRelu
|
||||
type_convert<float>(d1);
|
||||
|
||||
float result = 0;
|
||||
Relu{}.template operator()<float>(result, x);
|
||||
result = x > 0 ? x : 0;
|
||||
|
||||
e = type_convert<half_t>(result);
|
||||
}
|
||||
@@ -429,7 +429,7 @@ struct ScaleAddScaleAddRelu
|
||||
type_convert<float>(d1);
|
||||
|
||||
float result = 0;
|
||||
Relu{}.template operator()<float>(result, x);
|
||||
result = x > 0 ? x : 0;
|
||||
|
||||
e = type_convert<bhalf_t>(result);
|
||||
}
|
||||
@@ -441,7 +441,7 @@ struct ScaleAddScaleAddRelu
|
||||
const float x = type_convert<float>(c) * alpha1_ + alpha2_ * d0 + d1;
|
||||
|
||||
float result = 0;
|
||||
Relu{}.template operator()<float>(result, x);
|
||||
result = x > 0 ? x : 0;
|
||||
|
||||
e = type_convert<int8_t>(result);
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user