mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
optimize offset math in dma
This commit is contained in:
@@ -36,6 +36,8 @@ struct ExecutionConfig final
|
||||
int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values)
|
||||
bool time_kernel = false; // (0=no, 1=yes)
|
||||
int verbosity = 0; // (0=no info, 1=verbose info)
|
||||
int warm_up = 10;
|
||||
int repeat = 10;
|
||||
};
|
||||
|
||||
struct ProblemSizeSplitK final
|
||||
@@ -86,6 +88,8 @@ bool parse_cmd_args(int argc,
|
||||
if(argc >= 12)
|
||||
{
|
||||
problem_size.KBatch = std::stoi(argv[11]);
|
||||
config.warm_up = std::stoi(argv[12]);
|
||||
config.repeat = std::stoi(argv[13]);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -411,8 +415,9 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
std::cout << "Computing GEMM on device..." << std::endl << std::endl;
|
||||
}
|
||||
|
||||
float ave_time =
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50});
|
||||
float ave_time = invoker.Run(
|
||||
argument,
|
||||
StreamConfig{nullptr, config.time_kernel, config.verbosity, config.warm_up, config.repeat});
|
||||
|
||||
bool res_verified = true;
|
||||
if(config.do_verification > 0)
|
||||
@@ -484,14 +489,15 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
// partial sums(K/ScaleBlockSize)]
|
||||
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
|
||||
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K / ck::pack_size_v<ADataType> +
|
||||
sizeof(BDataType) * K * N / ck::pack_size_v<BDataType> +
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K / ck::pack_size_v<ADataType> + //
|
||||
sizeof(BDataType) * K * N / ck::pack_size_v<BDataType> + //
|
||||
sizeof(CDataType) * M * N +
|
||||
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;
|
||||
sizeof(XDataType) * M * K / ScaleBlockSize +
|
||||
sizeof(XDataType) * N * K / ScaleBlockSize;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
float gb_per_sec = static_cast<float>(num_btype) / 1e6f / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << device_op.GetTypeString() << std::endl;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -64,15 +64,24 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr auto block_slice_lengths = BlockSliceLengths{};
|
||||
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
|
||||
static constexpr auto wave_thread_cluster_lengths =
|
||||
Sequence<ThreadClusterLengths{}.At(I0),
|
||||
ThreadClusterLengths{}.At(I1) * 64 / ThreadGroup::GetNumOfThread(),
|
||||
1>{};
|
||||
static constexpr auto wave_cluster_lengths =
|
||||
Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{};
|
||||
|
||||
static constexpr auto thread_single_load_size = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
|
||||
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
|
||||
// It makes the whole wavefront load contiguous memory, what is required for direct loads.
|
||||
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto wave_single_load_size =
|
||||
wave_thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
|
||||
|
||||
static __device__ constexpr bool AreThreadClusterLengthsValid()
|
||||
@@ -171,10 +180,16 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId() / 64));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
|
||||
const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size;
|
||||
|
||||
SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
|
||||
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
|
||||
// We don't need threadwise offset for lds since it was calculate by HW
|
||||
// We still need input the wavewise offset.
|
||||
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin);
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
@@ -222,7 +237,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
// Loop over the destination block and copy data.
|
||||
static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
|
||||
const auto src_offset = src_coord_.GetOffset();
|
||||
const auto dst_offset = dst_coord_.GetOffset();
|
||||
const auto dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset());
|
||||
|
||||
// Check if src data is not in the logic padding area.
|
||||
const bool is_src_valid =
|
||||
@@ -312,6 +327,8 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
static constexpr auto wave_cluster_desc_ =
|
||||
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
|
||||
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
|
||||
@@ -1895,8 +1895,9 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
static_assert(
|
||||
is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<BElementwiseOperation, tensor_operation::element_wise::PassThrough>);
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
// divide block work by [M, N]
|
||||
|
||||
Reference in New Issue
Block a user