Code refactor (#175)

* format

* improving pipeline

* fix typo

* format

* adding thread group

* adding thread group

* adding thread group

* adding gemm pipeline

* tweak

* refactor

* refactor

* add missing type convert

* refactor

* refactor

* refactor

* clean

* fix build

* refactor

* format

* clean up

* use remove_cvref_t

* clean

* clean up

* clean up

* clean up
This commit is contained in:
Chao Liu
2022-05-09 14:57:59 -05:00
committed by GitHub
parent a3c910ac6c
commit ec7c2e912e
52 changed files with 1167 additions and 1912 deletions

View File

@@ -51,7 +51,7 @@ template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename DstElementwiseOperation,
typename ElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
@@ -70,12 +70,11 @@ struct ThreadwiseTensorSliceTransfer_v1r3
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(
const DstDesc& dst_desc,
const Index& dst_slice_origin_idx,
const DstElementwiseOperation& dst_element_op)
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
const Index& dst_slice_origin_idx,
const ElementwiseOperation& element_op)
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
dst_element_op_{dst_element_op}
element_op_{element_op}
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
@@ -136,13 +135,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData dst_v;
SrcData v;
// apply element-wise operation
dst_element_op_(dst_v, src_buf[Number<src_offset>{}]);
element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert
dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(dst_v);
dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(v);
});
const bool is_dst_valid =
@@ -213,7 +212,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
private:
DstCoord dst_coord_;
const DstElementwiseOperation dst_element_op_;
const ElementwiseOperation element_op_;
}; // struct ThreadwiseTensorSliceTransfer_v1r3
// Assume:

View File

@@ -102,8 +102,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1
// apply pointwise operation
static_for<0, ScalarPerVector, 1>{}([&](auto i) {
element_op_(dst_vector_container.template AsType<DstData>()(i),
src_vector_container.template AsType<SrcData>()[i]);
SrcData v;
// apply element-wise operation
element_op_(v, src_vector_container.template AsType<SrcData>()[i]);
// apply type convert
dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
});
const bool is_dst_valid =