optimize offset math in dma

This commit is contained in:
aska-0096
2025-05-21 10:28:00 +00:00
committed by mtgu0705
parent 513f92f5b9
commit 5ea3fe488d
4 changed files with 31 additions and 9 deletions

View File

@@ -36,6 +36,8 @@ struct ExecutionConfig final
int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values)
bool time_kernel = false; // (0=no, 1=yes)
int verbosity = 0; // (0=no info, 1=verbose info)
int warm_up = 10;
int repeat = 10;
};
struct ProblemSizeSplitK final
@@ -86,6 +88,8 @@ bool parse_cmd_args(int argc,
if(argc >= 12)
{
problem_size.KBatch = std::stoi(argv[11]);
config.warm_up = std::stoi(argv[12]);
config.repeat = std::stoi(argv[13]);
}
}
else
@@ -411,7 +415,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
}
float ave_time =
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50});
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, config.warm_up, config.repeat});
bool res_verified = true;
if(config.do_verification > 0)
@@ -482,14 +486,20 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
// Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) + scaling of
// partial sums(K/ScaleBlockSize)]
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
auto APackedSize =
ck::is_same_v<ck::remove_cvref_t<ADataType>, ck::f4x2_pk_t> ? 2 : 1;
auto BPackedSize =
ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::f4x2_pk_t> ? 2 : 1;
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
std::size_t num_btype = sizeof(ADataType) * M * K/APackedSize + sizeof(BDataType) * K* N/BPackedSize +
sizeof(CDataType) * M * N +
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;
sizeof(XDataType) * M * K / ScaleBlockSize +
sizeof(XDataType) * N * K / ScaleBlockSize;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
float gb_per_sec = static_cast<float>(num_btype) / static_cast<float>(1.E6) / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << device_op.GetTypeString() << std::endl;

View File

@@ -64,15 +64,19 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto block_slice_lengths = BlockSliceLengths{};
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
static constexpr auto wave_thread_cluster_lengths = Sequence<ThreadClusterLengths{}.At(I0), ThreadClusterLengths{}.At(I1)*64/ThreadGroup::GetNumOfThread(),1>{};
static constexpr auto wave_cluster_lengths = Sequence<1, ThreadGroup::GetNumOfThread()/64, 1>{};
static constexpr auto thread_single_load_size = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
// It makes the whole wavefront load contiguous memory, what is required for direct loads.
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
static constexpr auto wave_single_load_size= wave_thread_cluster_lengths*thread_single_load_size;
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
static __device__ constexpr bool AreThreadClusterLengthsValid()
@@ -167,11 +171,17 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
const auto thread_cluster_idx =
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
const auto wave_cluster_idx =
wave_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()/64));
const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size;
SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
// We don't need threadwise offset for lds since it was calculate by HW
// We still need input the wavewise offset.
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin);
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -220,7 +230,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// CK_PRINT<decltype(dst_access_lengths), decltype(ordered_dst_access_idx)>();
const auto src_offset = src_coord_.GetOffset();
const auto dst_offset = dst_coord_.GetOffset();
const auto dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset());
// printf("Tid: %03d, src_offset: %d, dst_offset: %d\n", get_thread_local_1d_id(),
// src_coord_.GetOffset(), dst_coord_.GetOffset());
// Check if src data is not in the logic padding area.
@@ -311,6 +321,8 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
static constexpr auto wave_cluster_desc_ =
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
SrcCoord src_coord_;
DstCoord dst_coord_;

View File

@@ -1895,8 +1895,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
// const AElementwiseOperation a_element_op{};
// const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]

View File

@@ -1020,7 +1020,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t src_element_space_size)
{
// Direct loads require that each thread reads and writes exactly a single DWORD.
constexpr auto dword_bytes = 4;
// constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
// static_assert(bytes_per_thread == dword_bytes);