Merge branch origin/wip-f4 into andriy/wip-f4

This commit is contained in:
Andriy Roshchenko
2025-05-23 22:14:30 +00:00
11 changed files with 501 additions and 475 deletions

View File

@@ -36,6 +36,8 @@ struct ExecutionConfig final
int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values)
bool time_kernel = false; // (0=no, 1=yes)
int verbosity = 0; // (0=no info, 1=verbose info)
int warm_up = 10;
int repeat = 10;
};
struct ProblemSizeSplitK final
@@ -86,6 +88,8 @@ bool parse_cmd_args(int argc,
if(argc >= 12)
{
problem_size.KBatch = std::stoi(argv[11]);
config.warm_up = std::stoi(argv[12]);
config.repeat = std::stoi(argv[13]);
}
}
else
@@ -282,22 +286,13 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
// ck::utils::FillConstant<ADataType>{a_data_element(1.0f)}(a_m_k);
// ck::utils::FillConstant<BDataType>{b_data_element(1.0f)}(b_k_n);
if constexpr(ck::is_same_v<XDataType, ck::e8m0_bexp_t>)
{
a_m_k_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{120, 129}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
// ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
// ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(b_k_n_scale);
}
else
{
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(a_m_k_scale);
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(b_k_n_scale);
// ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
// ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(0.5f)}(b_k_n_scale);
}
static_assert(ck::is_same_v<XDataType, ck::e8m0_bexp_t>);
a_m_k_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{120, 129}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
// ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
// ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(b_k_n_scale);
break;
@@ -420,8 +415,9 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
std::cout << "Computing GEMM on device..." << std::endl << std::endl;
}
float ave_time =
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50});
float ave_time = invoker.Run(
argument,
StreamConfig{nullptr, config.time_kernel, config.verbosity, config.warm_up, config.repeat});
bool res_verified = true;
if(config.do_verification > 0)
@@ -493,14 +489,14 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
// partial sums(K/ScaleBlockSize)]
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
std::size_t num_btype = sizeof(ADataType) * M * K / ck::packed_size_v<ADataType> +
sizeof(BDataType) * K * N / ck::packed_size_v<BDataType> +
sizeof(CDataType) * M * N +
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;
std::size_t num_btype =
sizeof(ADataType) * M * K / ck::packed_size_v<ADataType> +
sizeof(BDataType) * K * N / ck::packed_size_v<BDataType> + sizeof(CDataType) * M * N +
sizeof(XDataType) * M * K / ScaleBlockSize + sizeof(XDataType) * N * K / ScaleBlockSize;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
float gb_per_sec = static_cast<float>(num_btype) / 1e6f / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << device_op.GetTypeString() << std::endl;

View File

@@ -23,8 +23,8 @@ using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
@@ -50,14 +50,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
192, // MPerBlock
256, // MPerBlock
256, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
16, // MPerXDL
16, // NPerXDL
6, // MXdlPerWave
8, // MXdlPerWave
8, // NXdlPerWave
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
@@ -65,14 +65,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
false, // ABlockLdsExtraM
true, // ABlockLdsExtraM
S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
false, // BBlockLdsExtraN
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock

View File

@@ -203,8 +203,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
? HotLoopInstList::B_LDS_Read_Inst_Num
: HotLoopInstList::B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
// constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
// constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
@@ -243,29 +243,21 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
constexpr auto mfma_stages_more =
num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
// constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
// constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
if constexpr(i < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma < num_dswrite_per_issue_a)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma < num_dswrite_per_issue_a)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
@@ -274,23 +266,15 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma < num_dswrite_per_issue_a)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma < num_dswrite_per_issue_b)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
@@ -392,14 +376,14 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
ABlockBuffer& a_block_bufs,
const ABlockTransferStep& a_block_copy_step,
// BBlockCopy
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
BBlockBuffer& b_block_bufs,
const BBlockTransferStep& b_block_copy_step,
// CThread
CThreadBuffer& c_thread_buf,
@@ -427,8 +411,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0));
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I0));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
@@ -476,18 +460,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
b_scale_grid_desc,
make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// Global prefetch 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Local prefetch 1
// Local prefetch 1, sync the async load
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
@@ -502,7 +476,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_block_bufs(I0),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
@@ -524,7 +498,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_block_bufs(I0),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
@@ -536,6 +510,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
});
});
// Global prefetch 2
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1));
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I1));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
__builtin_amdgcn_sched_barrier(0);
@@ -548,13 +529,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
do
{
auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
// __builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf));
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(scale_comp_buf));
// Prefetch a_scales
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
@@ -699,7 +680,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
// t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc.
// t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc.
// k = 0 k = 1
block_sync_lds();
// __builtin_amdgcn_s_waitcnt(3952);
// block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
@@ -716,7 +698,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_block_bufs(scale_mem_buf),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
@@ -740,7 +722,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_block_bufs(scale_mem_buf),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
@@ -798,10 +780,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
});
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
@@ -880,6 +858,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
});
});
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
@@ -897,7 +876,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_block_bufs(I1),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
@@ -920,7 +899,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_block_bufs(I1),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -42,13 +42,16 @@ namespace ck {
template <typename ThreadGroup,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t ScalarPerVector>
index_t ScalarPerVector,
bool SrcXor = true>
struct ThreadGroupTensorSliceTransfer_DirectLoad
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
@@ -61,15 +64,24 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto block_slice_lengths = BlockSliceLengths{};
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
static constexpr auto wave_thread_cluster_lengths =
Sequence<ThreadClusterLengths{}.At(I0),
ThreadClusterLengths{}.At(I1) * 64 / ThreadGroup::GetNumOfThread(),
1>{};
static constexpr auto wave_cluster_lengths =
Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{};
static constexpr auto thread_single_load_size = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
// It makes the whole wavefront load contiguous memory, what is required for direct loads.
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
static constexpr auto wave_single_load_size =
wave_thread_cluster_lengths * thread_single_load_size;
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
static __device__ constexpr bool AreThreadClusterLengthsValid()
@@ -96,8 +108,12 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
// VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the
// first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive
// elements = 64 consecutive DWORDs.
#if defined(__gfx950__)
int num_contiguous_dwords = 4;
#else
int num_contiguous_dwords = 1;
bool is_contiguous = true;
#endif
bool is_contiguous = true;
static_for<0, nDim, 1>{}([&](auto i) {
if(is_contiguous)
{
@@ -141,11 +157,11 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
"When loading more than one element per thread at once, the contiguous "
"dimension must be the same between source and destination.");
constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
static_assert(bytes_per_thread_load == dword_bytes,
"Direct load transfer requires each thread to load exactly a single "
"DWORD of data.");
// constexpr auto dword_bytes = 4;
// constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
// static_assert(bytes_per_thread_load == dword_bytes,
// "Direct load transfer requires each thread to load exactly a single "
// "DWORD of data.");
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
@@ -156,18 +172,24 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
"The number of threads cannot be less than the number of elements in "
"thread cluster lengths.");
static_assert(
AreThreadClusterLengthsValid(),
"Thread cluster lengths are incorrect. They must be set in a way that allows a single "
"wavefront to write contiguous DWORDs into LDS memory. ");
// static_assert(
// AreThreadClusterLengthsValid(),
// "Thread cluster lengths are incorrect. They must be set in a way that allows a single
// " "wavefront to write contiguous DWORDs into LDS memory. ");
const auto thread_cluster_idx =
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId() / 64));
const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size;
SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
// We don't need threadwise offset for lds since it was calculate by HW
// We still need input the wavewise offset.
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin);
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -215,7 +237,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
// Loop over the destination block and copy data.
static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
const auto src_offset = src_coord_.GetOffset();
const auto dst_offset = dst_coord_.GetOffset();
const auto dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset());
// Check if src data is not in the logic padding area.
const bool is_src_valid =
@@ -303,7 +325,10 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
}
private:
static constexpr auto thread_cluster_desc_ = make_cluster_descriptor(ThreadClusterLengths{});
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
static constexpr auto wave_cluster_desc_ =
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
SrcCoord src_coord_;
DstCoord dst_coord_;

View File

@@ -299,119 +299,43 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <=
128 * 128 * 64 * 2)
MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
? 2
: 1
: 2;
if(has_main_k_block_loop)
{
// Tail number always full
constexpr auto TailNumChoices = []() {
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
// Tail number could be Odd or Even
return Tuple<constant<TailNumber::Full>>{};
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
#if 1
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
#endif
}
return Tuple<constant<TailNumber::Even>, constant<TailNumber::Odd>>{};
else
{
throw std::runtime_error("wrong! BlkGemmPipelineVer");
static_assert(false, "Unexpected BlkGemmPipelineVer!");
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
}();
const TailNumber tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(K_split);
using BoolChoices = Tuple<ck::true_type, ck::false_type>;
static_for_product<BoolChoices,
BoolChoices,
remove_cvref_t<decltype(TailNumChoices)>>{}(
[&](auto mainloop_choice, auto KBatch_cond_choice, auto tail_num_choice) {
constexpr auto CGlobalMemoryDataOperation =
KBatch_cond_choice.value ? InMemoryDataOperationEnum::AtomicAdd
: InMemoryDataOperationEnum::Set;
if(mainloop_choice.value == has_main_k_block_loop &&
KBatch_cond_choice.value == (arg.KBatch > 1) &&
tail_num_choice.value == tail_num)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< //
GridwiseGemm,
mainloop_choice.value,
CGlobalMemoryDataOperation,
minimum_occupancy,
tail_num_choice.value>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
});
return ave_time;
}

View File

@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/env.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp"
namespace ck {
@@ -76,9 +77,10 @@ __global__ void
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
p_shared_0,
p_shared_1,
karg);
@@ -198,7 +200,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
NPerXdl,
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk/APackedSize);
is_scale_mfma>::selected_mfma.k_per_blk /
APackedSize);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -265,10 +268,18 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
__host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
return transform_tensor_descriptor(
constexpr auto permuted_desc = transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_xor_with_modulo_transform(make_tuple(Number<MN>{}, Number<K0>{})),
make_pass_through_transform(Number<K1>{})),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
return transform_tensor_descriptor(
permuted_desc,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(make_tuple(Number<MNXdlPerWave / MNXdlPack>{},
Number<MNWaves>{},
@@ -351,12 +362,29 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// not pad M or K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
return a_grid_desc_ak0_m_ak1;
const auto a_grid_desc_permuted = transform_tensor_descriptor(
a_grid_desc_ak0_m_ak1,
make_tuple(make_pass_through_transform(K / KPerBlock),
make_xor_with_modulo_transform(make_tuple(M, AK0Number)),
make_pass_through_transform(AK1Value)),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
const auto a_grid_desc = transform_tensor_descriptor(
a_grid_desc_permuted,
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)),
make_pass_through_transform(M),
make_pass_through_transform(AK1Value)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_grid_desc;
}
}
@@ -442,12 +470,30 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(
make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
return b_grid_desc_bk0_n_bk1;
const auto b_grid_desc_permuted = transform_tensor_descriptor(
b_grid_desc_bk0_n_bk1,
make_tuple(make_pass_through_transform(K / KPerBlock),
make_xor_with_modulo_transform(make_tuple(N, BK0Number)),
make_pass_through_transform(BK1Value)),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
const auto b_grid_desc = transform_tensor_descriptor(
b_grid_desc_permuted,
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, BK0Number)),
make_pass_through_transform(N),
make_pass_through_transform(BK1Value)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_grid_desc;
}
else
{
@@ -648,10 +694,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
bool is_reduce_ = false)
: Problem{M_,
N_,
K_/APackedSize,
StrideA_/APackedSize,
K_ / APackedSize,
StrideA_ / APackedSize,
StrideScaleA_,
StrideB_/BPackedSize,
StrideB_ / BPackedSize,
StrideScaleB_,
StrideC_,
k_batch_},
@@ -723,21 +769,23 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// Calculate A scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize/APackedSize);
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize/APackedSize) * karg.StrideScaleA;
a_scale_k_split_offset =
k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_scale_k_split_offset = k_id * (karg.KRead / (ScaleBlockSize/BPackedSize)) * karg.StrideScaleB;
b_scale_k_split_offset =
k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize/BPackedSize);
b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
}
if(k_id < (karg.KBatch - 1))
@@ -771,9 +819,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// A matrix in LDS memory, dst of blockwise copy
if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
// contiguous in LDS
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1));
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
}
// xor tensor transformation request more unnecessary vgpr usage, would cause register spill
// in some cases.
@@ -888,9 +937,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// B matrix in LDS memory, dst of blockwise copy
if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
// contiguous in lds
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
make_tuple(BK1Number, Number<KPerBlock>{}, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
@@ -1074,7 +1124,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
static_assert(KPerBlock % (ScaleBlockSize/BPackedSize) == 0,
static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
"KPerBlock should be multiple of ScaleBlockSize");
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
@@ -1381,67 +1431,42 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
make_multi_index(0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
make_multi_index(0, 0, 0));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
@@ -1449,12 +1474,11 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared),
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
sizeof(ADataType)),
reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
a_block_space_size_aligned * sizeof(ADataType)),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
@@ -1556,7 +1580,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// shuffle C and write out
{
// printf("c_thread_buf %f %f\n", c_thread_buf[I0], c_thread_buf[I1]);
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
@@ -1801,15 +1824,17 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// A/B shuffled scale for better 8-bit scale access pattern
// MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(make_tuple(
problem.M / (MXdlPack * MPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize/APackedSize)) / (KXdlPack * 64 / MPerXdl),
64 * KXdlPack * MXdlPack / scale_pack_size_a));
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
make_tuple(problem.M / (MXdlPack * MPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
(KXdlPack * 64 / MPerXdl),
64 * KXdlPack * MXdlPack / scale_pack_size_a));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(make_tuple(
problem.N / (NXdlPack * NPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize/BPackedSize)) / (KXdlPack * 64 / NPerXdl),
64 * KXdlPack * NXdlPack / scale_pack_size_b));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
make_tuple(problem.N / (NXdlPack * NPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
(KXdlPack * 64 / NPerXdl),
64 * KXdlPack * NXdlPack / scale_pack_size_b));
Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(a_scale_grid_desc_am_ak),
@@ -1855,12 +1880,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock)
{
ignore = p_a_scale_grid;
ignore = a_scale_grid_desc_am_ak;
// TODO: Implement 2 LDS version
static_assert(false, "Not implemented");
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -1868,12 +1887,17 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// A Scale buffer
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
// B Scale buffer
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
static_assert(
is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, tensor_operation::element_wise::PassThrough>);
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
@@ -1909,67 +1933,42 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
make_multi_index(0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
make_multi_index(0, 0, 0));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
@@ -2006,76 +2005,99 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
// B scale
static constexpr auto mfma =
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>{};
static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
static constexpr auto KPerThread = KPerBlock / K0PerXdlops;
// Initial thread mapping for:
// BlockSize = 256
// MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2
// For each [m0, n0] tile, there are 4 waves:
// tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0]
// tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1]
// tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0]
// tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1]
const index_t ScaleSliceSizeN = NXdlPerWave;
static constexpr auto ScaleSliceSizeK = (KPerThread + (ScaleBlockSize/BPackedSize) - 1) / (ScaleBlockSize/BPackedSize);
static constexpr auto KBlockScaleSliceSizeK =
(KPerBlock + (ScaleBlockSize/BPackedSize) - 1) / (ScaleBlockSize/BPackedSize);
// BlockSize = 128
// MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1
// For each [m0, n0] tile, there are 2 waves:
// tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0]
// tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0]
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
// TODO: Document initial thread mapping for more combinations of parameters
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
auto b_thread_offset_n =
get_thread_local_1d_id() % NPerXdl +
(get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) % NWaves * NPerXdl;
auto b_thread_offset_k =
(get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / NPerXdl * KPerThread;
// static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleDataType,
BScaleDataType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<1, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
ScaleSliceSizeK,
1,
false>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock + b_thread_offset_n,
b_thread_offset_k / (ScaleBlockSize/BPackedSize)));
// auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) /
// mfma.selected_mfma.num_threads_per_blk;
constexpr auto b_scale_thread_slice_copy_step =
make_tuple(make_multi_index(NWaves * NPerXdl, 0),
make_multi_index(-NPerBlock, 0),
make_multi_index(-NPerBlock, KBlockScaleSliceSizeK));
// A wave access continuous memory
auto thread_offset_shuffled =
get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_bufs,
b_block_slice_copy_step,
c_thread_buf,
b_scale_grid_desc_bn_ak,
b_scale_thread_desc,
b_scale_thread_copy,
b_scale_grid_buf,
b_scale_thread_slice_copy_step,
num_k_block_main_loop);
auto a_thread_offset_m = waveId_m;
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
AScaleDataType,
AScaleDataType,
decltype(a_scale_grid_desc_am_ak),
decltype(BlockwiseGemmPipe::a_scale_thread_desc),
Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true>(a_scale_grid_desc_am_ak,
make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
0,
thread_offset_shuffled / scale_pack_size_a));
auto b_thread_offset_n = waveId_n;
auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
BScaleDataType,
BScaleDataType,
decltype(b_scale_grid_desc_bn_ak),
decltype(BlockwiseGemmPipe::b_scale_thread_desc),
Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true>(b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
0,
thread_offset_shuffled / scale_pack_size_b));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_bufs,
b_block_slice_copy_step,
c_thread_buf,
a_scale_grid_desc_am_ak,
a_scale_thread_copy,
a_scale_grid_buf,
b_scale_grid_desc_bn_ak,
b_scale_thread_copy,
b_scale_grid_buf,
num_k_block_main_loop);
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
@@ -2087,16 +2109,18 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
@@ -2110,19 +2134,25 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
// shuffle
M1, // M1 = MWave
M2, // M2 = MXdlPack
M3, // M3 * M4 * M5 = MPerXdl
M4,
M5)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave) per
// shuffle
N1, // N1 = NWave
N2, // N2 = NXdlPack
N3))), // N3 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
make_tuple(Sequence<>{},
Sequence<0, 2, 4, 6, 7, 8>{},
Sequence<>{},
Sequence<1, 3, 5, 9>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
@@ -2134,8 +2164,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
@@ -2144,8 +2174,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
@@ -2153,36 +2183,39 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
CShuffleNXdlPerWavePerShuffle / NXdlPack,
I1,
I1,
M2,
N2,
M3,
I1,
M5,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9,
1,
InMemoryDataOperationEnum::Set,
1,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
m_thread_data_on_block_idx[I5],
n_thread_data_on_block_idx[I3]),
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
@@ -2212,12 +2245,23 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
NXdlPerWave / NXdlPack,
1,
1,
MXdlPack,
NXdlPack,
M2,
1,
M4,
1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
CShuffleNXdlPerWavePerShuffle / NXdlPack,
1,
1,
MXdlPack,
NXdlPack,
M2,
1,
M4,
@@ -2273,6 +2317,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run_2Lds(const ADataType* p_a_grid,
const AScaleDataType* p_a_scale_grid,
const BDataType* p_b_grid,
const BScaleDataType* p_b_scale_grid,
CDataType* p_c_grid,
@@ -2286,22 +2331,33 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize/BPackedSize)),
make_tuple(problem.StrideScaleB, 1));
// A/B shuffled scale for better 8-bit scale access pattern
// MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
make_tuple(problem.M / (MXdlPack * MPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
(KXdlPack * 64 / MPerXdl),
64 * KXdlPack * MXdlPack / scale_pack_size_a));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
make_tuple(problem.N / (NXdlPack * NPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
(KXdlPack * 64 / NPerXdl),
64 * KXdlPack * NXdlPack / scale_pack_size_b));
Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
decltype(a_scale_grid_desc_am_ak),
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_scale_grid_desc_bn_ak),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(p_a_grid,
p_a_scale_grid,
p_b_grid,
p_b_scale_grid,
p_c_grid,
@@ -2309,6 +2365,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
p_shared_1,
problem,
a_grid_desc_ak0_m_ak1,
a_scale_grid_desc_am_ak,
b_grid_desc_bk0_n_bk1,
b_scale_grid_desc_bn_ak,
c_grid_desc_mblock_mperblock_nblock_nperblock);

View File

@@ -1022,7 +1022,12 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
// Direct loads require that each thread reads and writes exactly a single DWORD.
constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
#if defined(__gfx950__)
static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 ||
bytes_per_thread == dword_bytes * 4);
#else
static_assert(bytes_per_thread == dword_bytes);
#endif
#ifndef CK_CODE_GEN_RTC
const uint32_t* global_ptr =
@@ -1059,7 +1064,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
#endif
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0);
#endif
}
#endif

View File

@@ -1,10 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/functional.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/tuple.hpp"
namespace ck {
@@ -70,4 +71,35 @@ struct static_for<0, N, 1> : detail::make_applier<N>
using detail::make_applier<N>::operator();
};
template <typename... Is>
struct static_for_range
{
template <typename F>
__host__ __device__ constexpr void operator()(F f) const
{
// tweak -fbracket-depth if compilation fails. Clang default limit is 256
(f(Is{}), ...);
}
};
template <typename... Ts>
struct static_for_product;
template <typename... Is>
struct static_for_product<Tuple<Is...>> : public static_for_range<Is...>
{
};
template <typename... Is, typename... Rest>
struct static_for_product<Tuple<Is...>, Rest...>
{
template <typename F>
__host__ __device__ constexpr void operator()(F f) const
{
static_for_product<Tuple<Is...>>{}([&](auto i0) { //
static_for_product<Rest...>{}([&](auto... is) { //
f(i0, is...);
});
});
}
};
} // namespace ck

View File

@@ -5,14 +5,22 @@
namespace ck {
template <auto v>
struct constant
{
using value_type = decltype(v);
using type = constant; // using injected-class-name
static constexpr value_type value = v;
__host__ __device__ constexpr operator value_type() const noexcept { return value; }
__host__ __device__ constexpr value_type operator()() const noexcept { return value; }
};
template <class T, T v>
struct integral_constant
struct integral_constant : constant<v>
{
static constexpr T value = v;
typedef T value_type;
typedef integral_constant type;
__host__ __device__ constexpr operator value_type() const noexcept { return value; }
__host__ __device__ constexpr value_type operator()() const noexcept { return value; }
};
template <typename TX, TX X, typename TY, TY Y>

View File

@@ -44,17 +44,18 @@ using device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_instances = std::tuple<
//#############################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#############################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 128, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 256, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 128, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 256, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 128, 16, 16, 16, 16, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 128, 16, 16, 16, 16, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
std::nullptr_t
// clang-format on
>;

View File

@@ -213,8 +213,7 @@ bool profile_gemm_mx_impl(int do_verification,
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
a_m_k_scale.GenerateTensorValue(
GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f}); // R[2^-125, 1]
a_m_k_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
b_k_n_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});