Added Multi_ABD support into Gemm and GroupedGemmFixedNK (#978)

* added an example grouped_gemm_multi_abd

* fixed ci

* add setElementwiseOp

* changed API

* clean code: add multiA into example

* fixed v7r2 copy

* add transpose

* clean

* fixed vector_load check

* Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* add reduce

* testing

* add example_b16_i8

* refactor example

* clean

* add mpading

* disable reduce for kbatch = 1

* seperate reduce device op

* add reduce op

* add guard for workspace_size

* add instances

* format

* fixed

* add client example

* add a colmajor

* add instances

* Update cmake-ck-dev.sh

* Update profile_gemm_splitk.cpp

* Update gridwise_gemm_xdlops_v2r4r2.hpp

* format

* Update profile_gemm_splitk.cpp

* fixed

* fixed

* adjust test

* adjust precision loss

* adjust test

* fixed

* add bf16_i8 scale bias

* fixed scale

* fixed scale elementwise_op

* revert contraction deviceop changes

* fixed

* Add AddFastGelu

* Revert "Merge branch 'jizhan/gemm_splitk_reduce' into grouped_gemm_multi_abd_fixed_nk_example"

This reverts commit 3b5d001efd, reversing
changes made to 943199a991.

* add Scales into elementwise

* add gemm_multi_abd client example

* add client examples

* add rcr and crr

* add grouped gemm client example

* add grouped gemm client example

* add instance for rcr crr

* format

* fixed

* fixed cmake

* fixed

* fixed client_example

* format

* fixed contraction isSupport

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update device_reduce_threadwise.hpp

* clean

* Fixes

* Fix example

---------

Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
This commit is contained in:
zjing14
2024-04-15 21:09:45 -05:00
committed by GitHub
parent db376dd8a4
commit 12865fbf28
45 changed files with 6345 additions and 199 deletions

View File

@@ -10,38 +10,9 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions:
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
// instead
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
// tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
namespace detail {
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template <index_t VectorDim, index_t ScalarPerVector>
struct lambda_scalar_per_access
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
return (i == VectorDim) ? ScalarPerVector : 1;
}
};
template <index_t VectorDim>
struct lambda_scalar_step_in_vector
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
return (i == VectorDim) ? 1 : 0;
}
};
} // namespace detail
// Assume:
// 1. src:
// 1. SrcDesc is known at compile-time

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions:
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
// instead
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
// tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
namespace detail {
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template <index_t VectorDim, index_t ScalarPerVector>
struct lambda_scalar_per_access
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
return (i == VectorDim) ? ScalarPerVector : 1;
}
};
template <index_t VectorDim>
struct lambda_scalar_step_in_vector
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
return (i == VectorDim) ? 1 : 0;
}
};
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template <index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t DstVectorDim,
index_t DstScalarPerVector>
struct lambda_scalar_per_access_for_src_and_dst
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
if(i == SrcVectorDim && i == DstVectorDim)
{
return math::lcm(SrcScalarPerVector, DstScalarPerVector);
}
else if(i == SrcVectorDim)
{
return SrcScalarPerVector;
}
else if(i == DstVectorDim)
{
return DstScalarPerVector;
}
else
{
return 1;
}
}
};
} // namespace detail
} // namespace ck

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -7,44 +7,13 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/utility/is_detected.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
namespace detail {
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template <index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t DstVectorDim,
index_t DstScalarPerVector>
struct lambda_scalar_per_access_for_src_and_dst
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
if(i == SrcVectorDim && i == DstVectorDim)
{
return math::lcm(SrcScalarPerVector, DstScalarPerVector);
}
else if(i == SrcVectorDim)
{
return SrcScalarPerVector;
}
else if(i == DstVectorDim)
{
return DstScalarPerVector;
}
else
{
return 1;
}
}
};
} // namespace detail
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -8,9 +8,11 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/utility/is_detected.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
// Thread-level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
@@ -70,16 +72,18 @@ struct ThreadwiseTensorSliceTransfer_v7r2
static constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
using SrcSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
SrcDimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access)>>;
static constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
using SrcSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
SrcDimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access)>,
false>;
using DstSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DstDimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>>;
remove_cv_t<decltype(dst_scalar_per_access)>,
false>;
__device__ constexpr ThreadwiseTensorSliceTransfer_v7r2(
const SrcDescs& src_descs,
@@ -139,9 +143,9 @@ struct ThreadwiseTensorSliceTransfer_v7r2
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
{
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
auto dst_vectors = generate_vectors<DstDatas, DstScalarPerVector>();
auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
// copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) {
@@ -199,7 +203,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
},
Number<nDst>{});
@@ -214,10 +218,10 @@ struct ThreadwiseTensorSliceTransfer_v7r2
unpack2(element_op_, dst_data_refs, src_data_refs);
});
dst_vectors_tuple_(iAccess) = dst_vectors;
elm_vectors_tuple_(iAccess) = elm_vectors;
// move coordinate
if constexpr(iAccess.value != num_access - 1)
if constexpr(iAccess.value != src_num_access - 1)
{
constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess);
@@ -241,15 +245,113 @@ struct ThreadwiseTensorSliceTransfer_v7r2
});
}
__device__ void TransposeFromElmToDst()
{
using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>;
using SrcThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
SrcScalarPerVector,
decltype(GetSrcThreadScratchDescriptor()),
true>;
using DstThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
DstScalarPerVector,
decltype(GetDstThreadScratchDescriptor()),
true>;
SrcThreadScratch elm_thread_scratch_;
DstThreadScratch dst_thread_scratch_;
elm_thread_scratch_.data_ =
bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_);
if constexpr(SrcVectorDim != DstVectorDim &&
((is_same<half_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
// Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
// TODO: make this logic generic for all scenario
constexpr auto src_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
SrcScalarPerVector,
DstVectorDim,
DstScalarPerVector>{},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
constexpr auto data_idx = access_idx * scalar_per_access;
constexpr auto data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
// get DstScalarPerVector # of read-only references to src vectors from
// src_thread_scratch_
const auto src_vector_refs = generate_tie(
[&](auto i) -> const src_vector_t& {
// i increment corresponds to movement in DstVectorDim
return elm_thread_scratch_.GetVectorTypeReference(
data_idx_seq + i * dst_scalar_step_in_vector);
},
Number<num_src_vector>{});
// get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
auto dst_vector_refs = generate_tie(
[&](auto i) -> dst_vector_t& {
// i increment corresponds to movement in SrcVectorDim
return dst_thread_scratch_.GetVectorTypeReference(
data_idx_seq + i * src_scalar_step_in_vector);
},
Number<num_dst_vector>{});
// do data transpose
transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
});
}
else
{
static_ford<SliceLengths>{}(
[&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; });
}
dst_vectors_tuple_ = bit_cast<decltype(dst_vectors_tuple_)>(dst_thread_scratch_.data_);
}
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template <typename DstBuffers,
enable_if_t<DstDescs::Size() == DstBuffers::Size(), bool> = false>
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
{
TransposeFromElmToDst();
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[iAccess];
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[Number<iAccess>{}];
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
@@ -269,7 +371,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
});
// move coordinate
if constexpr(iAccess.value != num_access - 1)
if constexpr(iAccess.value != dst_num_access - 1)
{
constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess);
@@ -312,28 +414,126 @@ struct ThreadwiseTensorSliceTransfer_v7r2
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
if constexpr(num_access == 0)
if constexpr(src_num_access == 0)
{
return typename SrcSpaceFillingCurve::Index{};
}
else
{
return SrcSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return SrcSpaceFillingCurve::GetStepBetween(Number<src_num_access - 1>{}, Number<0>{});
}
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
if constexpr(num_access == 0)
if constexpr(dst_num_access == 0)
{
return typename DstSpaceFillingCurve::Index{};
}
else
{
return DstSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return DstSpaceFillingCurve::GetStepBetween(Number<dst_num_access - 1>{}, Number<0>{});
}
}
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
// constexpr auto src_scalar_per_access = generate_sequence(
// detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(src_access_lengths_and_vector_length[i],
src_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
// 1st stage of transforms
// constexpr auto dst_scalar_per_access = generate_sequence(
// detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(dst_access_lengths_and_vector_length[i],
dst_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
template <index_t ISrc>
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
@@ -372,11 +572,14 @@ struct ThreadwiseTensorSliceTransfer_v7r2
private:
using SrcVectorsType = decltype(generate_vectors<SrcDatas, SrcScalarPerVector>());
using ElmVectorsType = decltype(generate_vectors<DstDatas, SrcScalarPerVector>());
using DstVectorsType = decltype(generate_vectors<DstDatas, DstScalarPerVector>());
static constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess();
static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();
StaticallyIndexedArray<DstVectorsType, num_access> dst_vectors_tuple_;
StaticallyIndexedArray<ElmVectorsType, src_num_access> elm_vectors_tuple_;
StaticallyIndexedArray<DstVectorsType, dst_num_access> dst_vectors_tuple_;
SrcCoords src_coords_;
DstCoords dst_coords_;