mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4340 (commit 70a312f)
Implement device_grouped_gemm_fixed_nk_bias for RDNA4 ## Proposed changes Summary: - Modified implementation for grouped_gemm_fixed_nk_bias - FP16 WMMA examples - WMMA instances - Profiler for grouped_gemm_fixed_nk_bias - Add WMMA instances to existing tests **This PR depends on PR https://github.com/ROCm/rocm-libraries/pull/4299 and should be merged after it. Only the last 6 commits are in the scope of this PR.** ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [x] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [x] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [x] I have added inline documentation which enables the maintainers to understand the motivation - [x] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [x] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
9a32f0ea19
commit
75aea70c2c
@@ -7,22 +7,55 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace element_wise {
|
||||
struct SplitKAdd
|
||||
{
|
||||
static constexpr const char* name = "SplitKAdd";
|
||||
|
||||
__host__ __device__ void set_kbatch(const index_t& id, const index_t& total)
|
||||
{
|
||||
kbatch_id = id;
|
||||
KBatch = total;
|
||||
}
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
|
||||
{
|
||||
if(kbatch_id == KBatch - 1)
|
||||
{
|
||||
add_op(y, x0, x1);
|
||||
}
|
||||
else
|
||||
{
|
||||
passthrough_op(y, x0);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
index_t kbatch_id = 0;
|
||||
index_t KBatch = 1;
|
||||
static constexpr auto add_op = Add{};
|
||||
static constexpr auto passthrough_op = PassThrough{};
|
||||
};
|
||||
} // namespace element_wise
|
||||
|
||||
namespace device {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
@@ -110,6 +143,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
|
||||
while(id_local < local_grid_size)
|
||||
{
|
||||
const auto block_2_etile_map =
|
||||
GroupedGemmBlock2ETileMap(local_b2c_tile_map, group_start, id_off);
|
||||
|
||||
const auto tile_index =
|
||||
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(tile_index[Number<0>{}]);
|
||||
|
||||
auto c_element_op_copy(c_element_op);
|
||||
if constexpr(std::is_same_v<decltype(c_element_op_copy),
|
||||
ck::tensor_operation::element_wise::SplitKAdd>)
|
||||
{
|
||||
c_element_op_copy.set_kbatch(kbatch_id, k_batch_);
|
||||
}
|
||||
|
||||
KernelArgument kernel_arg{std::array<const void*, 1>{gemmTransKernelArg.p_a_grid},
|
||||
std::array<const void*, 1>{gemmTransKernelArg.p_b_grid},
|
||||
gemmTransKernelArg.p_ds_grid,
|
||||
@@ -124,15 +172,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
k_batch_,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
c_element_op_copy,
|
||||
false};
|
||||
|
||||
const auto block_2_etile_map =
|
||||
GroupedGemmBlock2ETileMap(local_b2c_tile_map, group_start, id_off);
|
||||
|
||||
const auto tile_index =
|
||||
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
const auto splitk_batch_offset =
|
||||
typename GridwiseGemm::SplitKBatchOffset(kernel_arg, tile_index[Number<0>{}]);
|
||||
|
||||
@@ -813,7 +855,9 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK<ALayout
|
||||
}
|
||||
|
||||
if constexpr(!std::is_same_v<CDEElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough>)
|
||||
ck::tensor_operation::element_wise::PassThrough> &&
|
||||
!std::is_same_v<CDEElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::SplitKAdd>)
|
||||
{
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user