mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Code refactor (#175)
* format * improving pipeline * fix typo * format * adding thread group * adding thread group * adding thread group * adding gemm pipeline * tweak * refactor * refactor * add missing type convert * refactor * refactor * refactor * clean * fix build * refactor * format * clean up * use remove_cvref_t * clean * clean up * clean up * clean up
This commit is contained in:
@@ -51,7 +51,7 @@ template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename DstElementwiseOperation,
|
||||
typename ElementwiseOperation,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t DstVectorDim,
|
||||
@@ -70,12 +70,11 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_idx,
|
||||
const DstElementwiseOperation& dst_element_op)
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_idx,
|
||||
const ElementwiseOperation& element_op)
|
||||
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
|
||||
dst_element_op_{dst_element_op}
|
||||
element_op_{element_op}
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
@@ -136,13 +135,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
SrcData dst_v;
|
||||
SrcData v;
|
||||
|
||||
// apply element-wise operation
|
||||
dst_element_op_(dst_v, src_buf[Number<src_offset>{}]);
|
||||
element_op_(v, src_buf[Number<src_offset>{}]);
|
||||
|
||||
// apply type convert
|
||||
dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(dst_v);
|
||||
dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(v);
|
||||
});
|
||||
|
||||
const bool is_dst_valid =
|
||||
@@ -213,7 +212,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
|
||||
private:
|
||||
DstCoord dst_coord_;
|
||||
const DstElementwiseOperation dst_element_op_;
|
||||
const ElementwiseOperation element_op_;
|
||||
}; // namespace ThreadwiseTensorSliceTransfer_v1r3
|
||||
|
||||
// Assume:
|
||||
|
||||
@@ -102,8 +102,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1
|
||||
|
||||
// apply pointwise operation
|
||||
static_for<0, ScalarPerVector, 1>{}([&](auto i) {
|
||||
element_op_(dst_vector_container.template AsType<DstData>()(i),
|
||||
src_vector_container.template AsType<SrcData>()[i]);
|
||||
SrcData v;
|
||||
|
||||
// apply element-wise operation
|
||||
element_op_(v, src_vector_container.template AsType<SrcData>()[i]);
|
||||
|
||||
// apply type convert
|
||||
dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
|
||||
});
|
||||
|
||||
const bool is_dst_valid =
|
||||
|
||||
Reference in New Issue
Block a user