mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Added Multi_ABD support into Gemm and GroupedGemmFixedNK (#978)
* added an example grouped_gemm_multi_abd * fixed ci * add setElementwiseOp * changed API * clean code: add multiA into example * fixed v7r2 copy * add transpose * clean * fixed vector_load check * Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * add reduce * testing * add example_b16_i8 * refactor example * clean * add mpading * disable reduce for kbatch = 1 * seperate reduce device op * add reduce op * add guard for workspace_size * add instances * format * fixed * add client example * add a colmajor * add instances * Update cmake-ck-dev.sh * Update profile_gemm_splitk.cpp * Update gridwise_gemm_xdlops_v2r4r2.hpp * format * Update profile_gemm_splitk.cpp * fixed * fixed * adjust test * adjust precision loss * adjust test * fixed * add bf16_i8 scale bias * fixed scale * fixed scale elementwise_op * revert contraction deviceop changes * fixed * Add AddFastGelu * Revert "Merge branch 'jizhan/gemm_splitk_reduce' into grouped_gemm_multi_abd_fixed_nk_example" This reverts commit3b5d001efd, reversing changes made to943199a991. * add Scales into elementwise * add gemm_multi_abd client example * add client examples * add rcr and crr * add grouped gemm client example * add grouped gemm client example * add instance for rcr crr * format * fixed * fixed cmake * fixed * fixed client_example * format * fixed contraction isSupport * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update device_reduce_threadwise.hpp * clean * Fixes * Fix example --------- Co-authored-by: Jing Zhang <jizha@amd.com> Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
This commit is contained in:
@@ -439,7 +439,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
|
||||
template <typename BLayout, GemmSpecialization GemmSpec>
|
||||
__host__ __device__ static auto
|
||||
MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
MakeBGridDescriptor_N_K(const index_t NRaw, const index_t KRaw, const index_t StrideB)
|
||||
{
|
||||
constexpr auto matrix_padder =
|
||||
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
|
||||
@@ -463,15 +463,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
|
||||
template <typename BsLayout, GemmSpecialization GemmSpec>
|
||||
__host__ __device__ static auto
|
||||
MakeBsGridDescriptor_N_K(const std::array<index_t, NumBTensor>& KRaws,
|
||||
const std::array<index_t, NumBTensor>& NRaws,
|
||||
MakeBsGridDescriptor_N_K(const std::array<index_t, NumBTensor>& NRaws,
|
||||
const std::array<index_t, NumBTensor>& KRaws,
|
||||
const std::array<index_t, NumBTensor>& BsStride)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
|
||||
|
||||
return MakeBGridDescriptor_N_K<BLayout, GemmSpec>(KRaws[i], NRaws[i], BsStride[i]);
|
||||
return MakeBGridDescriptor_N_K<BLayout, GemmSpec>(NRaws[i], KRaws[i], BsStride[i]);
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
}
|
||||
@@ -574,7 +574,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// HACK: this force m/n_block_data_idx_on_grid into SGPR
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
|
||||
@@ -595,8 +594,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
|
||||
Number<NumATensor>{});
|
||||
|
||||
#if 0
|
||||
static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
|
||||
"Src and Dst ScalarPerVector must be the same");
|
||||
#endif
|
||||
|
||||
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
|
||||
ThisThreadBlock,
|
||||
@@ -626,8 +627,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
|
||||
Number<NumBTensor>{});
|
||||
|
||||
#if 0
|
||||
static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
|
||||
"Src and Dst ScalarPerVector must be the same");
|
||||
#endif
|
||||
|
||||
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
|
||||
ThisThreadBlock,
|
||||
|
||||
Reference in New Issue
Block a user