mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Merge branch origin/wip-f4 into andriy/wip-f4
This commit is contained in:
@@ -36,6 +36,8 @@ struct ExecutionConfig final
|
||||
int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values)
|
||||
bool time_kernel = false; // (0=no, 1=yes)
|
||||
int verbosity = 0; // (0=no info, 1=verbose info)
|
||||
int warm_up = 10;
|
||||
int repeat = 10;
|
||||
};
|
||||
|
||||
struct ProblemSizeSplitK final
|
||||
@@ -86,6 +88,8 @@ bool parse_cmd_args(int argc,
|
||||
if(argc >= 12)
|
||||
{
|
||||
problem_size.KBatch = std::stoi(argv[11]);
|
||||
config.warm_up = std::stoi(argv[12]);
|
||||
config.repeat = std::stoi(argv[13]);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -282,22 +286,13 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
// ck::utils::FillConstant<ADataType>{a_data_element(1.0f)}(a_m_k);
|
||||
// ck::utils::FillConstant<BDataType>{b_data_element(1.0f)}(b_k_n);
|
||||
|
||||
if constexpr(ck::is_same_v<XDataType, ck::e8m0_bexp_t>)
|
||||
{
|
||||
a_m_k_scale.GenerateTensorValue(
|
||||
GeneratorTensor_2<XDataType>{120, 129}); // scales: {0.25, 0.5, 1, 2}
|
||||
b_k_n_scale.GenerateTensorValue(
|
||||
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
|
||||
// ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
|
||||
// ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(b_k_n_scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(a_m_k_scale);
|
||||
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(b_k_n_scale);
|
||||
// ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
|
||||
// ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(0.5f)}(b_k_n_scale);
|
||||
}
|
||||
static_assert(ck::is_same_v<XDataType, ck::e8m0_bexp_t>);
|
||||
a_m_k_scale.GenerateTensorValue(
|
||||
GeneratorTensor_2<XDataType>{120, 129}); // scales: {0.25, 0.5, 1, 2}
|
||||
b_k_n_scale.GenerateTensorValue(
|
||||
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
|
||||
// ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
|
||||
// ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(b_k_n_scale);
|
||||
|
||||
break;
|
||||
|
||||
@@ -420,8 +415,9 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
std::cout << "Computing GEMM on device..." << std::endl << std::endl;
|
||||
}
|
||||
|
||||
float ave_time =
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50});
|
||||
float ave_time = invoker.Run(
|
||||
argument,
|
||||
StreamConfig{nullptr, config.time_kernel, config.verbosity, config.warm_up, config.repeat});
|
||||
|
||||
bool res_verified = true;
|
||||
if(config.do_verification > 0)
|
||||
@@ -493,14 +489,14 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
// partial sums(K/ScaleBlockSize)]
|
||||
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
|
||||
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K / ck::packed_size_v<ADataType> +
|
||||
sizeof(BDataType) * K * N / ck::packed_size_v<BDataType> +
|
||||
sizeof(CDataType) * M * N +
|
||||
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K / ck::packed_size_v<ADataType> +
|
||||
sizeof(BDataType) * K * N / ck::packed_size_v<BDataType> + sizeof(CDataType) * M * N +
|
||||
sizeof(XDataType) * M * K / ScaleBlockSize + sizeof(XDataType) * N * K / ScaleBlockSize;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
float gb_per_sec = static_cast<float>(num_btype) / 1e6f / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << device_op.GetTypeString() << std::endl;
|
||||
|
||||
@@ -23,8 +23,8 @@ using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
|
||||
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
@@ -50,14 +50,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
|
||||
GemmSpec, // GemmSpec
|
||||
ScaleBlockSize, // ScaleBlockSize: Scaling block size
|
||||
256, // BlockSize: Thread block size
|
||||
192, // MPerBlock
|
||||
256, // MPerBlock
|
||||
256, // NPerBlock
|
||||
KPerBlock, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
6, // MXdlPerWave
|
||||
8, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
@@ -65,14 +65,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
16, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_AK1
|
||||
false, // ABlockLdsExtraM
|
||||
true, // ABlockLdsExtraM
|
||||
S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
16, // BBlockTransferSrcScalarPerVector
|
||||
16, // BBlockTransferDstScalarPerVector_BK1
|
||||
false, // BBlockLdsExtraN
|
||||
true, // BBlockLdsExtraN
|
||||
2, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
|
||||
@@ -203,8 +203,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
? HotLoopInstList::B_LDS_Read_Inst_Num
|
||||
: HotLoopInstList::B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
|
||||
constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
|
||||
// constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
|
||||
// constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
|
||||
@@ -243,29 +243,21 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
constexpr auto mfma_stages_more =
|
||||
num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
|
||||
|
||||
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
|
||||
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
|
||||
// constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
|
||||
// constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
|
||||
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
if constexpr(i < mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma < num_dswrite_per_issue_a)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma < num_dswrite_per_issue_a)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
@@ -274,23 +266,15 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma < num_dswrite_per_issue_a)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma < num_dswrite_per_issue_b)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
@@ -392,14 +376,14 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
ABlockBuffer& a_block_bufs,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
// BBlockCopy
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
BBlockBuffer& b_block_bufs,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
CThreadBuffer& c_thread_buf,
|
||||
@@ -427,8 +411,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0));
|
||||
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
@@ -476,18 +460,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// Global prefetch 2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Local prefetch 1
|
||||
// Local prefetch 1, sync the async load
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
|
||||
@@ -502,7 +476,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
Number<m0 % MXdlPack>{},
|
||||
I0,
|
||||
Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_block_bufs(I0),
|
||||
a_thread_desc_,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
@@ -524,7 +498,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
Number<n0 % NXdlPack>{},
|
||||
I0,
|
||||
Number<b_k_step_chunk>{}),
|
||||
b_block_buf,
|
||||
b_block_bufs(I0),
|
||||
b_thread_desc_,
|
||||
make_tuple(Number<n0 / NXdlPack>{},
|
||||
I0,
|
||||
@@ -536,6 +510,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
});
|
||||
});
|
||||
|
||||
// Global prefetch 2
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1));
|
||||
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I1));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -548,13 +529,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
|
||||
// __builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
a_blockwise_copy.Run(
|
||||
a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf));
|
||||
b_blockwise_copy.Run(
|
||||
b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(scale_comp_buf));
|
||||
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
@@ -699,7 +680,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
// t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc.
|
||||
// t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc.
|
||||
// k = 0 k = 1
|
||||
block_sync_lds();
|
||||
// __builtin_amdgcn_s_waitcnt(3952);
|
||||
// block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step =
|
||||
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
|
||||
@@ -716,7 +698,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
Number<m0 % MXdlPack>{},
|
||||
I0,
|
||||
Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_block_bufs(scale_mem_buf),
|
||||
a_thread_desc_,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
@@ -740,7 +722,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
Number<n0 % NXdlPack>{},
|
||||
I0,
|
||||
Number<b_k_step_chunk>{}),
|
||||
b_block_buf,
|
||||
b_block_bufs(scale_mem_buf),
|
||||
b_thread_desc_,
|
||||
make_tuple(Number<n0 / NXdlPack>{},
|
||||
I0,
|
||||
@@ -798,10 +780,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
@@ -880,6 +858,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
@@ -897,7 +876,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
Number<m0 % MXdlPack>{},
|
||||
I0,
|
||||
Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_block_bufs(I1),
|
||||
a_thread_desc_,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
@@ -920,7 +899,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
Number<n0 % NXdlPack>{},
|
||||
I0,
|
||||
Number<b_k_step_chunk>{}),
|
||||
b_block_buf,
|
||||
b_block_bufs(I1),
|
||||
b_thread_desc_,
|
||||
make_tuple(Number<n0 / NXdlPack>{},
|
||||
I0,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -42,13 +42,16 @@ namespace ck {
|
||||
template <typename ThreadGroup,
|
||||
typename BlockSliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SrcDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t ScalarPerVector>
|
||||
index_t ScalarPerVector,
|
||||
bool SrcXor = true>
|
||||
struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
@@ -61,15 +64,24 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr auto block_slice_lengths = BlockSliceLengths{};
|
||||
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
|
||||
static constexpr auto wave_thread_cluster_lengths =
|
||||
Sequence<ThreadClusterLengths{}.At(I0),
|
||||
ThreadClusterLengths{}.At(I1) * 64 / ThreadGroup::GetNumOfThread(),
|
||||
1>{};
|
||||
static constexpr auto wave_cluster_lengths =
|
||||
Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{};
|
||||
|
||||
static constexpr auto thread_single_load_size = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
|
||||
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
|
||||
// This makes the whole wavefront load contiguous memory, which is required for direct loads.
|
||||
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto wave_single_load_size =
|
||||
wave_thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
|
||||
|
||||
static __device__ constexpr bool AreThreadClusterLengthsValid()
|
||||
@@ -96,8 +108,12 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
// VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the
|
||||
// first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive
|
||||
// elements = 64 consecutive DWORDs.
|
||||
#if defined(__gfx950__)
|
||||
int num_contiguous_dwords = 4;
|
||||
#else
|
||||
int num_contiguous_dwords = 1;
|
||||
bool is_contiguous = true;
|
||||
#endif
|
||||
bool is_contiguous = true;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if(is_contiguous)
|
||||
{
|
||||
@@ -141,11 +157,11 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
"When loading more than one element per thread at once, the contiguous "
|
||||
"dimension must be the same between source and destination.");
|
||||
|
||||
constexpr auto dword_bytes = 4;
|
||||
constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
|
||||
static_assert(bytes_per_thread_load == dword_bytes,
|
||||
"Direct load transfer requires each thread to load exactly a single "
|
||||
"DWORD of data.");
|
||||
// constexpr auto dword_bytes = 4;
|
||||
// constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
|
||||
// static_assert(bytes_per_thread_load == dword_bytes,
|
||||
// "Direct load transfer requires each thread to load exactly a single "
|
||||
// "DWORD of data.");
|
||||
|
||||
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
|
||||
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
|
||||
@@ -156,18 +172,24 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
"The number of threads cannot be less than the number of elements in "
|
||||
"thread cluster lengths.");
|
||||
|
||||
static_assert(
|
||||
AreThreadClusterLengthsValid(),
|
||||
"Thread cluster lengths are incorrect. They must be set in a way that allows a single "
|
||||
"wavefront to write contiguous DWORDs into LDS memory. ");
|
||||
// static_assert(
|
||||
// AreThreadClusterLengthsValid(),
|
||||
// "Thread cluster lengths are incorrect. They must be set in a way that allows a single
|
||||
// " "wavefront to write contiguous DWORDs into LDS memory. ");
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId() / 64));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
|
||||
const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size;
|
||||
|
||||
SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
|
||||
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
|
||||
// We don't need a threadwise offset for LDS since it is calculated by HW.
|
||||
// We still need to supply the wavewise offset.
|
||||
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin);
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
@@ -215,7 +237,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
// Loop over the destination block and copy data.
|
||||
static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
|
||||
const auto src_offset = src_coord_.GetOffset();
|
||||
const auto dst_offset = dst_coord_.GetOffset();
|
||||
const auto dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset());
|
||||
|
||||
// Check if src data is not in the logic padding area.
|
||||
const bool is_src_valid =
|
||||
@@ -303,7 +325,10 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ = make_cluster_descriptor(ThreadClusterLengths{});
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
static constexpr auto wave_cluster_desc_ =
|
||||
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
|
||||
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
|
||||
@@ -299,119 +299,43 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
|
||||
? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
|
||||
MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <=
|
||||
128 * 128 * 64 * 2)
|
||||
MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
|
||||
? 2
|
||||
: 1
|
||||
: 2;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
constexpr auto TailNumChoices = []() {
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
// Tail number could be Odd or Even
|
||||
return Tuple<constant<TailNumber::Full>>{};
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
#if 1
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
return Tuple<constant<TailNumber::Even>, constant<TailNumber::Odd>>{};
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! BlkGemmPipelineVer");
|
||||
static_assert(false, "Unexpected BlkGemmPipelineVer!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
}();
|
||||
const TailNumber tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(K_split);
|
||||
using BoolChoices = Tuple<ck::true_type, ck::false_type>;
|
||||
static_for_product<BoolChoices,
|
||||
BoolChoices,
|
||||
remove_cvref_t<decltype(TailNumChoices)>>{}(
|
||||
[&](auto mainloop_choice, auto KBatch_cond_choice, auto tail_num_choice) {
|
||||
constexpr auto CGlobalMemoryDataOperation =
|
||||
KBatch_cond_choice.value ? InMemoryDataOperationEnum::AtomicAdd
|
||||
: InMemoryDataOperationEnum::Set;
|
||||
if(mainloop_choice.value == has_main_k_block_loop &&
|
||||
KBatch_cond_choice.value == (arg.KBatch > 1) &&
|
||||
tail_num_choice.value == tail_num)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< //
|
||||
GridwiseGemm,
|
||||
mainloop_choice.value,
|
||||
CGlobalMemoryDataOperation,
|
||||
minimum_occupancy,
|
||||
tail_num_choice.value>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
});
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -76,9 +77,10 @@ __global__ void
|
||||
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
|
||||
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg);
|
||||
@@ -198,7 +200,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
NPerXdl,
|
||||
ComputeTypeB,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk/APackedSize);
|
||||
is_scale_mfma>::selected_mfma.k_per_blk /
|
||||
APackedSize);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
@@ -265,10 +268,18 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
__host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
|
||||
{
|
||||
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
|
||||
constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
|
||||
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
constexpr auto permuted_desc = transform_tensor_descriptor(
|
||||
TileDesc_K0_MN_K1{},
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(Number<MN>{}, Number<K0>{})),
|
||||
make_pass_through_transform(Number<K1>{})),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
permuted_desc,
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
|
||||
make_unmerge_transform(make_tuple(Number<MNXdlPerWave / MNXdlPack>{},
|
||||
Number<MNWaves>{},
|
||||
@@ -351,12 +362,29 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// not pad M or K
|
||||
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
|
||||
a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
const auto a_grid_desc_permuted = transform_tensor_descriptor(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_tuple(make_pass_through_transform(K / KPerBlock),
|
||||
make_xor_with_modulo_transform(make_tuple(M, AK0Number)),
|
||||
make_pass_through_transform(AK1Value)),
|
||||
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
|
||||
|
||||
const auto a_grid_desc = transform_tensor_descriptor(
|
||||
a_grid_desc_permuted,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)),
|
||||
make_pass_through_transform(M),
|
||||
make_pass_through_transform(AK1Value)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
return a_grid_desc;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -442,12 +470,30 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// not pad N or K
|
||||
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
|
||||
b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
const auto b_grid_desc_permuted = transform_tensor_descriptor(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_tuple(make_pass_through_transform(K / KPerBlock),
|
||||
make_xor_with_modulo_transform(make_tuple(N, BK0Number)),
|
||||
make_pass_through_transform(BK1Value)),
|
||||
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
|
||||
|
||||
const auto b_grid_desc = transform_tensor_descriptor(
|
||||
b_grid_desc_permuted,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, BK0Number)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(BK1Value)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
return b_grid_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -648,10 +694,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
bool is_reduce_ = false)
|
||||
: Problem{M_,
|
||||
N_,
|
||||
K_/APackedSize,
|
||||
StrideA_/APackedSize,
|
||||
K_ / APackedSize,
|
||||
StrideA_ / APackedSize,
|
||||
StrideScaleA_,
|
||||
StrideB_/BPackedSize,
|
||||
StrideB_ / BPackedSize,
|
||||
StrideScaleB_,
|
||||
StrideC_,
|
||||
k_batch_},
|
||||
@@ -723,21 +769,23 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// Calculate A scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize/APackedSize);
|
||||
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize/APackedSize) * karg.StrideScaleA;
|
||||
a_scale_k_split_offset =
|
||||
k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
|
||||
}
|
||||
|
||||
// Calculate B scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_scale_k_split_offset = k_id * (karg.KRead / (ScaleBlockSize/BPackedSize)) * karg.StrideScaleB;
|
||||
b_scale_k_split_offset =
|
||||
k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize/BPackedSize);
|
||||
b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
|
||||
}
|
||||
|
||||
if(k_id < (karg.KBatch - 1))
|
||||
@@ -771,9 +819,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
// contiguous in LDS
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1));
|
||||
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
// The XOR tensor transformation needs extra VGPRs, which can cause register spills
|
||||
// in some cases.
|
||||
@@ -888,9 +937,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
// contiguous in lds
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
|
||||
make_tuple(BK1Number, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
@@ -1074,7 +1124,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
|
||||
"Invalid tuning param!");
|
||||
|
||||
static_assert(KPerBlock % (ScaleBlockSize/BPackedSize) == 0,
|
||||
static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
|
||||
"KPerBlock should be multiple of ScaleBlockSize");
|
||||
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
|
||||
@@ -1381,67 +1431,42 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
|
||||
BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
@@ -1449,12 +1474,11 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
|
||||
// Cast after lds
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ADataType*>(p_shared),
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
|
||||
sizeof(ADataType)),
|
||||
reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
|
||||
a_block_space_size_aligned * sizeof(ADataType)),
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
|
||||
@@ -1556,7 +1580,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
// printf("c_thread_buf %f %f\n", c_thread_buf[I0], c_thread_buf[I1]);
|
||||
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
|
||||
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
|
||||
"wrong!");
|
||||
@@ -1801,15 +1824,17 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
|
||||
// A/B shuffled scale for better 8-bit scale access pattern
|
||||
// MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
problem.M / (MXdlPack * MPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize/APackedSize)) / (KXdlPack * 64 / MPerXdl),
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a));
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(problem.M / (MXdlPack * MPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
|
||||
(KXdlPack * 64 / MPerXdl),
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a));
|
||||
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
problem.N / (NXdlPack * NPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize/BPackedSize)) / (KXdlPack * 64 / NPerXdl),
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b));
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(problem.N / (NXdlPack * NPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
|
||||
(KXdlPack * 64 / NPerXdl),
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b));
|
||||
|
||||
Run<decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_scale_grid_desc_am_ak),
|
||||
@@ -1855,12 +1880,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock)
|
||||
{
|
||||
ignore = p_a_scale_grid;
|
||||
ignore = a_scale_grid_desc_am_ak;
|
||||
|
||||
// TODO: Implement 2 LDS version
|
||||
static_assert(false, "Not implemented");
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
@@ -1868,12 +1887,17 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
// A Scale buffer
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
|
||||
// B Scale buffer
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
static_assert(
|
||||
is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<BElementwiseOperation, tensor_operation::element_wise::PassThrough>);
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
// divide block work by [M, N]
|
||||
@@ -1909,67 +1933,42 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
|
||||
BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
@@ -2006,76 +2005,99 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
// B scale
|
||||
static constexpr auto mfma =
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>{};
|
||||
static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
|
||||
static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
|
||||
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
|
||||
static constexpr auto KPerThread = KPerBlock / K0PerXdlops;
|
||||
// Initial thread mapping for:
|
||||
// BlockSize = 256
|
||||
// MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2
|
||||
// For each [m0, n0] tile, there are 4 waves:
|
||||
// tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0]
|
||||
// tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1]
|
||||
// tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0]
|
||||
// tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1]
|
||||
|
||||
const index_t ScaleSliceSizeN = NXdlPerWave;
|
||||
static constexpr auto ScaleSliceSizeK = (KPerThread + (ScaleBlockSize/BPackedSize) - 1) / (ScaleBlockSize/BPackedSize);
|
||||
static constexpr auto KBlockScaleSliceSizeK =
|
||||
(KPerBlock + (ScaleBlockSize/BPackedSize) - 1) / (ScaleBlockSize/BPackedSize);
|
||||
// BlockSize = 128
|
||||
// MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1
|
||||
// For each [m0, n0] tile, there are 2 waves:
|
||||
// tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0]
|
||||
// tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0]
|
||||
|
||||
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
|
||||
// TODO: Document initial thread mapping for more combinations of parameters
|
||||
|
||||
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
auto b_thread_offset_n =
|
||||
get_thread_local_1d_id() % NPerXdl +
|
||||
(get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) % NWaves * NPerXdl;
|
||||
auto b_thread_offset_k =
|
||||
(get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / NPerXdl * KPerThread;
|
||||
// static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
|
||||
|
||||
auto b_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<BScaleDataType,
|
||||
BScaleDataType,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(b_scale_thread_desc),
|
||||
Sequence<1, ScaleSliceSizeK>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
ScaleSliceSizeK,
|
||||
1,
|
||||
false>(
|
||||
b_scale_grid_desc_bn_ak,
|
||||
make_multi_index(block_n_id * NPerBlock + b_thread_offset_n,
|
||||
b_thread_offset_k / (ScaleBlockSize/BPackedSize)));
|
||||
// auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) /
|
||||
// mfma.selected_mfma.num_threads_per_blk;
|
||||
|
||||
constexpr auto b_scale_thread_slice_copy_step =
|
||||
make_tuple(make_multi_index(NWaves * NPerXdl, 0),
|
||||
make_multi_index(-NPerBlock, 0),
|
||||
make_multi_index(-NPerBlock, KBlockScaleSliceSizeK));
|
||||
// A wave access continuous memory
|
||||
auto thread_offset_shuffled =
|
||||
get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_bufs,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_bufs,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
b_scale_grid_desc_bn_ak,
|
||||
b_scale_thread_desc,
|
||||
b_scale_thread_copy,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_slice_copy_step,
|
||||
num_k_block_main_loop);
|
||||
auto a_thread_offset_m = waveId_m;
|
||||
|
||||
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
|
||||
AScaleDataType,
|
||||
AScaleDataType,
|
||||
decltype(a_scale_grid_desc_am_ak),
|
||||
decltype(BlockwiseGemmPipe::a_scale_thread_desc),
|
||||
Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
|
||||
Sequence<0, 1, 2>, // DimAccessOrder
|
||||
2, // SrcVectorDim
|
||||
KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
true>(a_scale_grid_desc_am_ak,
|
||||
make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
|
||||
0,
|
||||
thread_offset_shuffled / scale_pack_size_a));
|
||||
|
||||
auto b_thread_offset_n = waveId_n;
|
||||
|
||||
// Threadwise copy that reads this thread's packed B scales from the shuffled
// B-scale grid descriptor into BlockwiseGemmPipe::b_scale_thread_desc.
// The per-access vector length (SrcScalarPerVector) is taken along dim 2, so it
// must be derived from the same quantity as the dim-2 slice length: NXdlPack for
// the B operand. It previously used MXdlPack, which only matches when
// MXdlPack == NXdlPack.
auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
    BScaleDataType,
    BScaleDataType,
    decltype(b_scale_grid_desc_bn_ak),
    decltype(BlockwiseGemmPipe::b_scale_thread_desc),
    Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
    Sequence<0, 1, 2>,                                       // DimAccessOrder
    2,                                                       // SrcVectorDim
    KXdlPack * NXdlPack / scale_pack_size_b,                 // SrcScalarPerVector
    1,                                                       // SrcScalarStrideInVector
    true>(b_scale_grid_desc_bn_ak,
          // Origin: dim 0 is this block's first N tile plus the wave's N index,
          // dim 1 starts at 0, dim 2 is the lane offset in packed-scale units.
          make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
                           0,
                           thread_offset_shuffled / scale_pack_size_b));
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_bufs,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_bufs,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
a_scale_grid_desc_am_ak,
|
||||
a_scale_thread_copy,
|
||||
a_scale_grid_buf,
|
||||
b_scale_grid_desc_bn_ak,
|
||||
b_scale_thread_copy,
|
||||
b_scale_grid_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
|
||||
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
|
||||
"wrong!");
|
||||
static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
|
||||
CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
@@ -2087,16 +2109,18 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// TODO: hacky, fix it!
|
||||
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
|
||||
|
||||
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
|
||||
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
|
||||
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
|
||||
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
|
||||
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
|
||||
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
|
||||
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
|
||||
constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
|
||||
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
@@ -2110,19 +2134,25 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
make_tuple(
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
|
||||
M1, // M1 = MWave
|
||||
M2, // M2 * M3 * M4 = MPerXdl
|
||||
M3,
|
||||
M4)),
|
||||
Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
|
||||
// shuffle
|
||||
M1, // M1 = MWave
|
||||
M2, // M2 = MXdlPack
|
||||
M3, // M3 * M4 * M5 = MPerXdl
|
||||
M4,
|
||||
M5)),
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
|
||||
N1, // N1 = NWave
|
||||
N2))), // N2 = NPerXdl
|
||||
Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave) per
|
||||
// shuffle
|
||||
N1, // N1 = NWave
|
||||
N2, // N2 = NXdlPack
|
||||
N3))), // N3 = NPerXdl
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(
|
||||
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
|
||||
make_tuple(Sequence<>{},
|
||||
Sequence<0, 2, 4, 6, 7, 8>{},
|
||||
Sequence<>{},
|
||||
Sequence<1, 3, 5, 9>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
@@ -2134,8 +2164,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
|
||||
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
@@ -2144,8 +2174,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
|
||||
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_block_idx =
|
||||
@@ -2153,36 +2183,39 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
|
||||
CShuffleNXdlPerWavePerShuffle / NXdlPack,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
N2,
|
||||
M3,
|
||||
I1,
|
||||
M5,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
|
||||
9,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
n_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
m_thread_data_on_block_idx[I5],
|
||||
n_thread_data_on_block_idx[I3]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
// shuffle: blockwise copy C from LDS to global
|
||||
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
|
||||
@@ -2212,12 +2245,23 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
|
||||
NXdlPerWave / NXdlPack,
|
||||
1,
|
||||
1,
|
||||
MXdlPack,
|
||||
NXdlPack,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
|
||||
CShuffleNXdlPerWavePerShuffle / NXdlPack,
|
||||
1,
|
||||
1,
|
||||
MXdlPack,
|
||||
NXdlPack,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
@@ -2273,6 +2317,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
__device__ static void Run_2Lds(const ADataType* p_a_grid,
|
||||
const AScaleDataType* p_a_scale_grid,
|
||||
const BDataType* p_b_grid,
|
||||
const BScaleDataType* p_b_scale_grid,
|
||||
CDataType* p_c_grid,
|
||||
@@ -2286,22 +2331,33 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize/BPackedSize)),
|
||||
make_tuple(problem.StrideScaleB, 1));
|
||||
// A/B shuffled scale for better 8-bit scale access pattern
|
||||
// MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(problem.M / (MXdlPack * MPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
|
||||
(KXdlPack * 64 / MPerXdl),
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a));
|
||||
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(problem.N / (NXdlPack * NPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
|
||||
(KXdlPack * 64 / NPerXdl),
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b));
|
||||
|
||||
Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_scale_grid_desc_am_ak),
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(p_a_grid,
|
||||
p_a_scale_grid,
|
||||
p_b_grid,
|
||||
p_b_scale_grid,
|
||||
p_c_grid,
|
||||
@@ -2309,6 +2365,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
p_shared_1,
|
||||
problem,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_scale_grid_desc_am_ak,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b_scale_grid_desc_bn_ak,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock);
|
||||
|
||||
@@ -1022,7 +1022,12 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
// Direct loads require that each thread reads and writes exactly a single DWORD.
|
||||
constexpr auto dword_bytes = 4;
|
||||
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
|
||||
#if defined(__gfx950__)
|
||||
static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 ||
|
||||
bytes_per_thread == dword_bytes * 4);
|
||||
#else
|
||||
static_assert(bytes_per_thread == dword_bytes);
|
||||
#endif
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
const uint32_t* global_ptr =
|
||||
@@ -1059,7 +1064,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
#endif
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
|
||||
src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/functional.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -70,4 +71,35 @@ struct static_for<0, N, 1> : detail::make_applier<N>
|
||||
using detail::make_applier<N>::operator();
|
||||
};
|
||||
|
||||
template <typename... Is>
|
||||
struct static_for_range
|
||||
{
|
||||
template <typename F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
// tweak -fbracket-depth if compilation fails. Clang default limit is 256
|
||||
(f(Is{}), ...);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
struct static_for_product;
|
||||
template <typename... Is>
|
||||
struct static_for_product<Tuple<Is...>> : public static_for_range<Is...>
|
||||
{
|
||||
};
|
||||
template <typename... Is, typename... Rest>
|
||||
struct static_for_product<Tuple<Is...>, Rest...>
|
||||
{
|
||||
template <typename F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
static_for_product<Tuple<Is...>>{}([&](auto i0) { //
|
||||
static_for_product<Rest...>{}([&](auto... is) { //
|
||||
f(i0, is...);
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -5,14 +5,22 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <auto v>
|
||||
struct constant
|
||||
{
|
||||
using value_type = decltype(v);
|
||||
using type = constant; // using injected-class-name
|
||||
static constexpr value_type value = v;
|
||||
__host__ __device__ constexpr operator value_type() const noexcept { return value; }
|
||||
__host__ __device__ constexpr value_type operator()() const noexcept { return value; }
|
||||
};
|
||||
|
||||
template <class T, T v>
|
||||
struct integral_constant
|
||||
struct integral_constant : constant<v>
|
||||
{
|
||||
static constexpr T value = v;
|
||||
typedef T value_type;
|
||||
typedef integral_constant type;
|
||||
__host__ __device__ constexpr operator value_type() const noexcept { return value; }
|
||||
__host__ __device__ constexpr value_type operator()() const noexcept { return value; }
|
||||
};
|
||||
|
||||
template <typename TX, TX X, typename TY, TY Y>
|
||||
|
||||
@@ -44,17 +44,18 @@ using device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_instances = std::tuple<
|
||||
//#############################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#############################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version|
|
||||
//#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 128, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 256, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 128, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 256, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 128, 16, 16, 16, 16, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 128, 16, 16, 16, 16, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
|
||||
std::nullptr_t
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
@@ -213,8 +213,7 @@ bool profile_gemm_mx_impl(int do_verification,
|
||||
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
|
||||
a_m_k_scale.GenerateTensorValue(
|
||||
GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f}); // R[2^-125, 1]
|
||||
a_m_k_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
|
||||
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
|
||||
b_k_n_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
|
||||
|
||||
Reference in New Issue
Block a user