mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Merge commit '8f1274d9b655c2584b3643acac07ef813f31238e' into develop
This commit is contained in:
22
Jenkinsfile
vendored
22
Jenkinsfile
vendored
@@ -1642,14 +1642,9 @@ pipeline {
|
||||
ninja -j64 benchmark_gemm_preshuffle_all && \
|
||||
python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" \
|
||||
--warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_rrrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_ccrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_crrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_rcrr """
|
||||
ninja -j64 benchmark_gemm_multi_d_all && \
|
||||
python3 ../tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" \
|
||||
--warmup 5 --repeat 5 --verbose --json results.json """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
@@ -1682,14 +1677,9 @@ pipeline {
|
||||
ninja -j64 benchmark_gemm_preshuffle_all && \
|
||||
python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" \
|
||||
--warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_rrrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_ccrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_crrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_rcrr """
|
||||
ninja -j64 benchmark_gemm_multi_d_all && \
|
||||
python3 ../tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" \
|
||||
--warmup 5 --repeat 5 --verbose --json results.json """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
|
||||
@@ -49,7 +49,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
false, // PreshuffleB
|
||||
GemmConfig::PreshuffleB, // PreshuffleB
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
@@ -58,7 +58,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
true>;
|
||||
true>; // Persistence
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
@@ -86,10 +86,14 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
BDataType,
|
||||
scheduler>>::type;
|
||||
|
||||
using GemmPipeline =
|
||||
typename std::conditional<QuantMode == ck_tile::QuantType::BQuantGrouped,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>::type;
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
std::conditional_t<GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -141,5 +145,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv);
|
||||
int result1 = !run_grouped_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);
|
||||
return result1;
|
||||
}
|
||||
|
||||
@@ -10,9 +10,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_BQUANT_COMPUTE_V3 2
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
@@ -31,6 +28,22 @@ constexpr ck_tile::index_t get_k_warp_tile()
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
@@ -67,8 +80,9 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool PreshuffleB = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -85,10 +99,26 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
@@ -118,7 +148,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol");
|
||||
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol")
|
||||
.insert("init", "0", "0. Random, 2. One(s) (Constant)");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -163,6 +163,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
const int kbatch = arg_parser.get_int("kbatch");
|
||||
const int init_method = arg_parser.get_int("init");
|
||||
bool validate = arg_parser.get_bool("validate");
|
||||
const ck_tile::index_t QuantGroupSize = 128;
|
||||
|
||||
@@ -203,6 +204,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
@@ -280,6 +282,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
stride_AQs[i] = 0; // No A quantization
|
||||
stride_BQs[i] =
|
||||
ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout));
|
||||
}
|
||||
|
||||
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
|
||||
@@ -313,10 +321,20 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
|
||||
<< " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl;
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
|
||||
if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{1.f, 1.f}(aq_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{1.f, 1.f}(bq_tensors[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
@@ -329,8 +347,18 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
bq_dev_buf.push_back(
|
||||
std::make_unique<ck_tile::DeviceMem>(bq_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
if constexpr(GemmConfig::PreshuffleB && QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host =
|
||||
ck_tile::shuffle_b<GemmConfig>(b_k_n_tensors[i]);
|
||||
b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
}
|
||||
|
||||
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
|
||||
bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
|
||||
c_m_n_dev_buf[i]->SetZero();
|
||||
|
||||
@@ -292,13 +292,15 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
|
||||
}
|
||||
|
||||
static constexpr auto MAccVgprs =
|
||||
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()[I2];
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
|
||||
{
|
||||
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
|
||||
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
|
||||
|
||||
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
|
||||
constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
|
||||
return make_naive_tensor_descriptor(
|
||||
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
|
||||
|
||||
@@ -42,7 +42,8 @@ template <typename ThreadGroup,
|
||||
index_t DstScalarPerVector,
|
||||
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
typename ThreadTransferDstResetCoordinateAfterRunFlags,
|
||||
index_t NumThreadScratch = 1>
|
||||
index_t NumThreadScratch = 1,
|
||||
typename InterDatas = DstDatas>
|
||||
struct ThreadGroupTensorSliceTransfer_v7r3
|
||||
{
|
||||
static constexpr index_t nDim =
|
||||
@@ -97,7 +98,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
|
||||
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! ThreadGroup::GetNumOfThread() too small");
|
||||
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
@@ -123,7 +124,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
|
||||
const SrcBuffers& src_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
|
||||
@@ -138,7 +139,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
|
||||
DstBuffers dst_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
|
||||
@@ -148,6 +149,36 @@ struct ThreadGroupTensorSliceTransfer_v7r3
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstBuffers,
|
||||
typename DstVgprDescs,
|
||||
typename DstVgprBuffers,
|
||||
index_t ThreadScratchId = 0>
|
||||
__device__ void
|
||||
RunWriteAndStoreVgpr(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
const DstVgprDescs& dst_vgpr_desc,
|
||||
DstVgprBuffers dst_vgpr_buf,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value &&
|
||||
is_detected<is_tuple, decltype(dst_vgpr_buf)>::value)
|
||||
threadwise_transfer_.RunWriteAndStoreVgpr(
|
||||
dst_descs, dst_bufs, dst_vgpr_desc, dst_vgpr_buf, thread_scratch_id);
|
||||
else if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
|
||||
threadwise_transfer_.RunWriteAndStoreVgpr(
|
||||
dst_descs, dst_bufs, dst_vgpr_desc, tie(dst_vgpr_buf), thread_scratch_id);
|
||||
else if constexpr(is_detected<is_tuple, decltype(dst_vgpr_buf)>::value)
|
||||
threadwise_transfer_.RunWriteAndStoreVgpr(
|
||||
dst_descs, tie(dst_bufs), dst_vgpr_desc, dst_vgpr_buf, thread_scratch_id);
|
||||
else
|
||||
threadwise_transfer_.RunWriteAndStoreVgpr(
|
||||
dst_descs, tie(dst_bufs), dst_vgpr_desc, tie(dst_vgpr_buf), thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffers, typename DstBuffers>
|
||||
__device__ void Run(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
@@ -162,7 +193,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
|
||||
@@ -179,7 +210,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
|
||||
__device__ void
|
||||
MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
|
||||
@@ -212,7 +243,8 @@ struct ThreadGroupTensorSliceTransfer_v7r3
|
||||
DstScalarPerVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
ThreadTransferDstResetCoordinateAfterRunFlags,
|
||||
NumThreadScratch>;
|
||||
NumThreadScratch,
|
||||
InterDatas>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
|
||||
|
||||
@@ -60,7 +60,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
const long_index_t c_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
@@ -82,6 +84,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
|
||||
});
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
p_as_grid_shift,
|
||||
p_bs_grid_shift,
|
||||
@@ -91,7 +95,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op);
|
||||
karg.cde_element_op,
|
||||
epilogue_args);
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -46,12 +46,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
std::is_same_v<c_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
// The normal approach to batching would be to increase the grid size by just stretching out
|
||||
// the grid Z dimension (which is the outermost dimension), but this depends on lower level
|
||||
// functions not directly using the Z dimension for other calculations. As it turns out, k
|
||||
// batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now
|
||||
// we will use the grid Y dimension for batching. This may be a bit fragile.
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
const index_t g_idx = amd_wave_read_first_lane(blockIdx.y);
|
||||
|
||||
@@ -84,6 +86,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
|
||||
});
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
p_as_grid_shift,
|
||||
p_bs_grid_shift,
|
||||
@@ -94,7 +98,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op);
|
||||
karg.cde_element_op,
|
||||
epilogue_args);
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,896 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename EMeanVarDataType,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_gemm_multiple_d_welford_first_half_wmma_cshuffle_v3(
|
||||
typename GridwiseGemm::Argument karg,
|
||||
EMeanVarDataType* __restrict__ p_welford_mean_grid,
|
||||
EMeanVarDataType* __restrict__ p_welford_var_grid,
|
||||
int32_t* __restrict__ p_welford_count_grid)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<e_data_type, ck::half_t> ||
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueWelfordCShuffle>();
|
||||
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueWelfordCShuffle(
|
||||
p_welford_mean_grid, p_welford_var_grid, p_welford_count_grid, karg.M, karg.N);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared, splitk_batch_offset, karg, epilogue_args);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
ignore = karg;
|
||||
ignore = p_welford_mean_grid;
|
||||
ignore = p_welford_var_grid;
|
||||
ignore = p_welford_count_grid;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename GridwiseWelfordLayernorm,
|
||||
typename EMeanVarDataType,
|
||||
typename HDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename ComputeDataType,
|
||||
typename EHGridDesc_M_N,
|
||||
typename LayernormMeanVarGridDesc_M_NBlock,
|
||||
typename LayernormCountGridDesc_M_NBlock,
|
||||
typename GammaBetaGridDesc_N,
|
||||
typename HElementwiseOperation>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_welford_layernorm2d_second_half(
|
||||
const EMeanVarDataType* __restrict__ p_e_grid,
|
||||
const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
|
||||
const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
|
||||
const int32_t* __restrict__ p_in_welford_count_grid,
|
||||
const GammaDataType* __restrict__ p_gamma_grid,
|
||||
const BetaDataType* __restrict__ p_beta_grid,
|
||||
HDataType* __restrict__ p_h_grid,
|
||||
const EHGridDesc_M_N e_grid_desc_m_n,
|
||||
const EHGridDesc_M_N h_grid_desc_m_n,
|
||||
const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock,
|
||||
const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock,
|
||||
const GammaBetaGridDesc_N gamma_grid_desc_n,
|
||||
const GammaBetaGridDesc_N beta_grid_desc_n,
|
||||
index_t numMeanVarCountBlockTileIteration_N,
|
||||
index_t NBlockClusterLength,
|
||||
ComputeDataType epsilon,
|
||||
HElementwiseOperation h_element_op)
|
||||
{
|
||||
GridwiseWelfordLayernorm::Run(p_e_grid,
|
||||
p_in_welford_mean_grid,
|
||||
p_in_welford_var_grid,
|
||||
p_in_welford_count_grid,
|
||||
p_gamma_grid,
|
||||
p_beta_grid,
|
||||
p_h_grid,
|
||||
e_grid_desc_m_n,
|
||||
h_grid_desc_m_n,
|
||||
mean_var_grid_desc_m_nblock,
|
||||
count_grid_desc_m_nblock,
|
||||
gamma_grid_desc_n,
|
||||
beta_grid_desc_n,
|
||||
numMeanVarCountBlockTileIteration_N,
|
||||
NBlockClusterLength,
|
||||
epsilon,
|
||||
h_element_op);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename HLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename HDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename EMeanVarDataType, // LayerNorm
|
||||
typename GammaDataType, // LayerNorm
|
||||
typename BetaDataType, // LayerNorm
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename HElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector,
|
||||
typename LayernormThreadClusterSize_M_N,
|
||||
index_t LayernormThreadSliceSize_M,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = HDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
|
||||
: public DeviceGemmMultipleDLayernorm<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
HLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
HDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
HElementwiseOperation>
|
||||
{
|
||||
// EDataType, MeanDataType and VarDataType must be the same.
|
||||
using DeviceOp = DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static constexpr index_t LayernormHDstVectorSize = CDEShuffleBlockTransferScalarPerVector;
|
||||
static constexpr index_t LayernormGammaSrcVectorSize = CDEShuffleBlockTransferScalarPerVector;
|
||||
static constexpr index_t LayernormBetaSrcVectorSize = CDEShuffleBlockTransferScalarPerVector;
|
||||
static constexpr index_t LayernormESrcVectorSize = CDEShuffleBlockTransferScalarPerVector;
|
||||
static constexpr index_t LayernormThreadSliceSize_N = CDEShuffleBlockTransferScalarPerVector;
|
||||
|
||||
using LayernormBlockTileSize_M_N =
|
||||
Sequence<LayernormThreadClusterSize_M_N::At(0) * LayernormThreadSliceSize_M,
|
||||
LayernormThreadClusterSize_M_N::At(1) * LayernormThreadSliceSize_N>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using CDEShuffleBlockTransferScalarPerVectors =
|
||||
Sequence<CDEShuffleBlockTransferScalarPerVector,
|
||||
CDEShuffleBlockTransferScalarPerVector,
|
||||
CDEShuffleBlockTransferScalarPerVector>;
|
||||
|
||||
// GEMM + Welford 1st part kernel
|
||||
using GridwiseGemmWelford = GridwiseGemm_wmma_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
HLayout,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EMeanVarDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
PermuteA,
|
||||
PermuteB>;
|
||||
|
||||
// Welford 2nd part kernel
|
||||
template <typename DoPads, index_t MPerTile, index_t NPerTile>
|
||||
static auto MakeEHGridDescriptor_M_N(index_t M, index_t N, index_t Stride)
|
||||
{
|
||||
// Only support row major for E and H
|
||||
const auto grid_desc_m_n =
|
||||
make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(Stride, I1));
|
||||
return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
|
||||
}
|
||||
|
||||
template <index_t XPerTile>
|
||||
static auto MakeDescriptor_X(index_t X)
|
||||
{
|
||||
const auto grid_desc_x = make_naive_tensor_descriptor_packed(make_tuple(X));
|
||||
return PadTensorDescriptor(grid_desc_x, make_tuple(XPerTile), Sequence<true>{});
|
||||
}
|
||||
|
||||
using LayernormMeanVarGridDesc_M_NBlock =
|
||||
decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N<
|
||||
Sequence<true, true>,
|
||||
LayernormBlockTileSize_M_N::At(0),
|
||||
LayernormBlockTileSize_M_N::At(1)>(1, 1));
|
||||
|
||||
using LayernormCountGridDesc_M_NBlock =
|
||||
decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N<
|
||||
Sequence<true, true>,
|
||||
LayernormBlockTileSize_M_N::At(0),
|
||||
LayernormBlockTileSize_M_N::At(1)>(1, 1));
|
||||
|
||||
using GammaBetaGridDesc_N = decltype(MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(1));
|
||||
using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N<Sequence<true, true>, 1, 1>(1, 1, 1));
|
||||
|
||||
using GridwiseWelfordLayernorm =
|
||||
GridwiseWelfordSecondHalfLayernorm2d<EMeanVarDataType,
|
||||
HDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
EHGridDesc_M_N,
|
||||
LayernormMeanVarGridDesc_M_NBlock,
|
||||
LayernormCountGridDesc_M_NBlock,
|
||||
GammaBetaGridDesc_N,
|
||||
HElementwiseOperation,
|
||||
BlockSize,
|
||||
LayernormThreadClusterSize_M_N::At(I0),
|
||||
LayernormThreadClusterSize_M_N::At(I1),
|
||||
LayernormThreadSliceSize_M,
|
||||
LayernormThreadSliceSize_N,
|
||||
LayernormESrcVectorSize,
|
||||
LayernormHDstVectorSize,
|
||||
LayernormGammaSrcVectorSize,
|
||||
LayernormBetaSrcVectorSize>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const void* p_a_grid,
|
||||
const void* p_b_grid,
|
||||
std::array<const void*, NumDTensor> p_ds_grid,
|
||||
const void* p_gamma_grid,
|
||||
const void* p_beta_grid,
|
||||
void* p_h_grid,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideH,
|
||||
double epsilon,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
HElementwiseOperation h_element_op)
|
||||
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
|
||||
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
|
||||
p_ds_grid_{},
|
||||
p_workspace_e_grid_{nullptr},
|
||||
p_workspace_mean_{nullptr},
|
||||
p_workspace_var_{nullptr},
|
||||
p_workspace_count_{nullptr},
|
||||
p_gamma_grid_{static_cast<const GammaDataType*>(p_gamma_grid)},
|
||||
p_beta_grid_{static_cast<const BetaDataType*>(p_beta_grid)},
|
||||
p_h_grid_{static_cast<HDataType*>(p_h_grid)},
|
||||
layernorm_e_grid_desc_m_n_{
|
||||
DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>,
|
||||
LayernormBlockTileSize_M_N::At(0),
|
||||
LayernormBlockTileSize_M_N::At(1)>(
|
||||
MRaw, NRaw, StrideH)},
|
||||
layernorm_mean_var_grid_desc_m_nblock_{},
|
||||
layernorm_count_grid_desc_m_nblock_{},
|
||||
gamma_grid_desc_n_{
|
||||
DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
|
||||
beta_grid_desc_n_{
|
||||
DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
|
||||
h_grid_desc_m_n_{
|
||||
DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>,
|
||||
LayernormBlockTileSize_M_N::At(0),
|
||||
LayernormBlockTileSize_M_N::At(1)>(
|
||||
MRaw, NRaw, StrideH)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
h_element_op_{h_element_op},
|
||||
MRaw_{MRaw},
|
||||
NRaw_{NRaw},
|
||||
KRaw_{KRaw},
|
||||
StrideA_{StrideA},
|
||||
StrideB_{StrideB},
|
||||
StrideDs_{StrideDs},
|
||||
StrideH_{StrideH},
|
||||
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
|
||||
epsilon_{static_cast<AccDataType>(epsilon)}
|
||||
{
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) { p_ds_grid_[i] = p_ds_grid[i]; });
|
||||
|
||||
layernorm_mean_var_grid_desc_m_nblock_ =
|
||||
GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N<
|
||||
Sequence<true, true>,
|
||||
LayernormBlockTileSize_M_N::At(0),
|
||||
LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_);
|
||||
|
||||
layernorm_count_grid_desc_m_nblock_ =
|
||||
GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N<
|
||||
Sequence<true, true>,
|
||||
LayernormBlockTileSize_M_N::At(0),
|
||||
LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_);
|
||||
}
|
||||
|
||||
// pointers
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
std::array<const void*, NumDTensor> p_ds_grid_;
|
||||
void* p_workspace_e_grid_;
|
||||
void* p_workspace_mean_;
|
||||
void* p_workspace_var_;
|
||||
void* p_workspace_count_;
|
||||
const GammaDataType* p_gamma_grid_;
|
||||
const BetaDataType* p_beta_grid_;
|
||||
HDataType* p_h_grid_;
|
||||
|
||||
// tensor descriptors (Welford second half)
|
||||
EHGridDesc_M_N layernorm_e_grid_desc_m_n_;
|
||||
LayernormMeanVarGridDesc_M_NBlock layernorm_mean_var_grid_desc_m_nblock_;
|
||||
LayernormCountGridDesc_M_NBlock layernorm_count_grid_desc_m_nblock_;
|
||||
GammaBetaGridDesc_N gamma_grid_desc_n_;
|
||||
GammaBetaGridDesc_N beta_grid_desc_n_;
|
||||
EHGridDesc_M_N h_grid_desc_m_n_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
HElementwiseOperation h_element_op_;
|
||||
|
||||
index_t MRaw_;
|
||||
index_t NRaw_;
|
||||
index_t KRaw_;
|
||||
index_t StrideA_;
|
||||
index_t StrideB_;
|
||||
std::array<index_t, NumDTensor> StrideDs_;
|
||||
index_t StrideH_;
|
||||
index_t gemm_nblock_;
|
||||
AccDataType epsilon_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
typename GridwiseGemmWelford::Argument gemm_arg{
|
||||
std::array<const void*, 1>{arg.p_a_grid_},
|
||||
std::array<const void*, 1>{arg.p_b_grid_},
|
||||
arg.p_ds_grid_,
|
||||
static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
|
||||
arg.MRaw_,
|
||||
arg.NRaw_,
|
||||
arg.KRaw_,
|
||||
std::array<index_t, 1>{arg.StrideA_}, // StrideAs
|
||||
std::array<index_t, 1>{arg.StrideB_}, // StrideBs
|
||||
arg.StrideDs_, // StrideDs
|
||||
arg.StrideH_, // StrideE
|
||||
I1, // kbatch
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_};
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
gemm_arg.Print();
|
||||
GridwiseGemmWelford::BlockwiseGemmPipe::HotLoopInstList::Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemmWelford::CheckValidity(gemm_arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
|
||||
}
|
||||
|
||||
if(arg.p_workspace_e_grid_ == nullptr || arg.p_workspace_mean_ == nullptr ||
|
||||
arg.p_workspace_var_ == nullptr || arg.p_workspace_count_ == nullptr)
|
||||
throw std::runtime_error("wrong! WorkSpace pointer has not been set");
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) =
|
||||
GridwiseGemmWelford::CalculateGridSize(arg.MRaw_, arg.NRaw_, 1);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t K_split = (arg.KRaw_ + KPerBlock - 1) / KPerBlock * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop =
|
||||
GridwiseGemmWelford::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel_gemm_welford_first_half) {
|
||||
// Note: cache flushing not supported
|
||||
|
||||
const auto kernel_welford_second_half =
|
||||
kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
|
||||
EMeanVarDataType,
|
||||
HDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
EHGridDesc_M_N,
|
||||
LayernormMeanVarGridDesc_M_NBlock,
|
||||
LayernormCountGridDesc_M_NBlock,
|
||||
GammaBetaGridDesc_N,
|
||||
HElementwiseOperation>;
|
||||
|
||||
// First kernel launch: GEMM + Welford first part
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel_gemm_welford_first_half,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_arg,
|
||||
static_cast<EMeanVarDataType*>(arg.p_workspace_mean_),
|
||||
static_cast<EMeanVarDataType*>(arg.p_workspace_var_),
|
||||
static_cast<int32_t*>(arg.p_workspace_count_));
|
||||
|
||||
// Second kernel launch: Welford second part
|
||||
const auto M = arg.h_grid_desc_m_n_.GetLength(I0);
|
||||
const auto N = arg.h_grid_desc_m_n_.GetLength(I1);
|
||||
|
||||
index_t MBlockClusterLength =
|
||||
math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0));
|
||||
index_t NBlockClusterLength =
|
||||
math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(1));
|
||||
|
||||
auto grid_size = MBlockClusterLength * NBlockClusterLength;
|
||||
|
||||
index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil(
|
||||
arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1));
|
||||
|
||||
ave_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel_welford_second_half,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
|
||||
static_cast<const EMeanVarDataType*>(arg.p_workspace_mean_),
|
||||
static_cast<const EMeanVarDataType*>(arg.p_workspace_var_),
|
||||
static_cast<const int32_t*>(arg.p_workspace_count_),
|
||||
arg.p_gamma_grid_,
|
||||
arg.p_beta_grid_,
|
||||
arg.p_h_grid_,
|
||||
arg.layernorm_e_grid_desc_m_n_,
|
||||
arg.h_grid_desc_m_n_,
|
||||
arg.layernorm_mean_var_grid_desc_m_nblock_,
|
||||
arg.layernorm_count_grid_desc_m_nblock_,
|
||||
arg.gamma_grid_desc_n_,
|
||||
arg.beta_grid_desc_n_,
|
||||
numMeanVarCountBlockTileIteration_N,
|
||||
NBlockClusterLength,
|
||||
arg.epsilon_,
|
||||
arg.h_element_op_);
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy = []() {
|
||||
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}();
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
const auto kernel = kernel_gemm_multiple_d_welford_first_half_wmma_cshuffle_v3<
|
||||
GridwiseGemmWelford,
|
||||
EMeanVarDataType,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
const auto kernel = kernel_gemm_multiple_d_welford_first_half_wmma_cshuffle_v3<
|
||||
GridwiseGemmWelford,
|
||||
EMeanVarDataType,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
|
||||
{
|
||||
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
|
||||
|
||||
size_t workspace_size = 0;
|
||||
|
||||
int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;
|
||||
|
||||
// workspace for welford intermediate mean
|
||||
workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 128;
|
||||
|
||||
// workspace for welford intermediate variance
|
||||
workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 128;
|
||||
|
||||
// workspace for welford intermediate count
|
||||
workspace_size += pArg_->gemm_nblock_ * sizeof(int32_t) + 128;
|
||||
|
||||
if constexpr(!is_same_v<EMeanVarDataType, HDataType>)
|
||||
workspace_size += pArg_->MRaw_ * pArg_->NRaw_ * sizeof(EMeanVarDataType);
|
||||
|
||||
return (workspace_size);
|
||||
};
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* pArg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
|
||||
|
||||
pArg_->p_workspace_ = p_workspace;
|
||||
|
||||
int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;
|
||||
|
||||
// setup buffer used for intermediate welford mean
|
||||
pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
|
||||
|
||||
index_t mean_space_sz = gemm_welford_size * sizeof(EMeanVarDataType);
|
||||
mean_space_sz = math::integer_least_multiple(mean_space_sz, 128);
|
||||
|
||||
// setup buffer used for intermediate welford variance
|
||||
pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz;
|
||||
|
||||
index_t variance_space_sz = gemm_welford_size * sizeof(EMeanVarDataType);
|
||||
variance_space_sz = math::integer_least_multiple(variance_space_sz, 128);
|
||||
|
||||
// setup buffer used for intermediate welford count
|
||||
pArg_->p_workspace_count_ =
|
||||
reinterpret_cast<char*>(pArg_->p_workspace_var_) + variance_space_sz;
|
||||
|
||||
index_t count_space_sz = gemm_welford_size * sizeof(int32_t);
|
||||
count_space_sz = math::integer_least_multiple(count_space_sz, 128);
|
||||
|
||||
if constexpr(!is_same_v<EMeanVarDataType, HDataType>)
|
||||
pArg_->p_workspace_e_grid_ =
|
||||
reinterpret_cast<char*>(pArg_->p_workspace_count_) + count_space_sz;
|
||||
else
|
||||
pArg_->p_workspace_e_grid_ = static_cast<void*>(pArg_->p_h_grid_);
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// No need to check for splitK because we force KBatch = 1 (no support)
|
||||
|
||||
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
|
||||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if((arg.KRaw_ % AK1 != 0 || arg.KRaw_ % BK1 != 0) &&
|
||||
!(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
typename GridwiseGemmWelford::Argument gemm_arg{
|
||||
std::array<const void*, 1>{arg.p_a_grid_},
|
||||
std::array<const void*, 1>{arg.p_b_grid_},
|
||||
arg.p_ds_grid_,
|
||||
static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
|
||||
arg.MRaw_,
|
||||
arg.NRaw_,
|
||||
arg.KRaw_,
|
||||
std::array<index_t, 1>{arg.StrideA_}, // StrideAs
|
||||
std::array<index_t, 1>{arg.StrideB_}, // StrideBs
|
||||
arg.StrideDs_, // StrideDs
|
||||
arg.StrideH_, // StrideE
|
||||
I1, // kbatch
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_};
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
GridwiseGemmWelford::MakeAsGridDescriptor_AK0_M_AK1(gemm_arg.M,
|
||||
gemm_arg.MPadded,
|
||||
gemm_arg.K,
|
||||
gemm_arg.KPadded,
|
||||
gemm_arg.StrideAs,
|
||||
gemm_arg.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
GridwiseGemmWelford::MakeBsGridDescriptor_BK0_N_BK1(gemm_arg.K,
|
||||
gemm_arg.KPadded,
|
||||
gemm_arg.N,
|
||||
gemm_arg.NPadded,
|
||||
gemm_arg.StrideBs,
|
||||
gemm_arg.BK0);
|
||||
|
||||
const auto M = a_grid_desc_ak0_m_ak1[I0].GetLength(I1);
|
||||
const auto N = b_grid_desc_bk0_n_bk1[I0].GetLength(I1);
|
||||
const auto K =
|
||||
a_grid_desc_ak0_m_ak1[I0].GetLength(I0) * a_grid_desc_ak0_m_ak1[I0].GetLength(I2);
|
||||
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemmWelford::CheckValidity(gemm_arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
void* p_h,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideH,
|
||||
double epsilon,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
HElementwiseOperation h_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_gamma,
|
||||
p_beta,
|
||||
p_h,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideH,
|
||||
epsilon,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
h_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
void* p_h,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideH,
|
||||
double epsilon,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
HElementwiseOperation h_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_gamma,
|
||||
p_beta,
|
||||
p_h,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideH,
|
||||
epsilon,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
h_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3"
|
||||
<< ">"
|
||||
<< "BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerWmma << "x"<<NPerWmma << ", "
|
||||
<< "WaveMap: "
|
||||
<< MRepeat << "x" << NRepeat << ", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
|
||||
<< "GemmSpec: "
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< "VmemWriteThreadCluster: "
|
||||
<< CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1) << ", "
|
||||
<< CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3) << ", "
|
||||
<< "LayerNormThreadCluster: "
|
||||
<< LayernormThreadClusterSize_M_N::At(I0) << ", "
|
||||
<< LayernormThreadClusterSize_M_N::At(I1) << ", "
|
||||
<< "LayerNormThreadSliceSize: "
|
||||
<< LayernormThreadSliceSize_M << ", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemmWelford::BlockwiseGemmPipe::PrefetchStages << ", "
|
||||
<< "KPack: "
|
||||
<< GridwiseGemmWelford::KPack;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -60,8 +60,8 @@ struct AddReluAdd
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
|
||||
half_t& y, const float& x0, const half_t& x1, const half_t& x2) const
|
||||
__host__ __device__ constexpr void operator()<float, float, half_t, half_t>(
|
||||
float& y, const float& x0, const half_t& x1, const half_t& x2) const
|
||||
{
|
||||
float a = x0 + x1;
|
||||
float b = a > 0 ? a : 0;
|
||||
@@ -69,6 +69,15 @@ struct AddReluAdd
|
||||
y = c;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
|
||||
half_t& y, const float& x0, const half_t& x1, const half_t& x2) const
|
||||
{
|
||||
float y_float;
|
||||
(*this)(y_float, x0, x1, x2);
|
||||
y = y_float;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
|
||||
bhalf_t& y, const float& x0, const bhalf_t& x1, const bhalf_t& x2) const
|
||||
|
||||
@@ -0,0 +1,510 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
typename CDEElementwiseOperation,
|
||||
typename ThisThreadBlock,
|
||||
typename BlockwiseGemmPipe,
|
||||
index_t BlockSize>
|
||||
struct EpilogueWelfordCShuffle
|
||||
: EpilogueCShuffleBase<DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
CDEElementwiseOperation,
|
||||
ThisThreadBlock,
|
||||
BlockwiseGemmPipe>
|
||||
{
|
||||
using Base = EpilogueCShuffleBase<
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
CDEElementwiseOperation,
|
||||
ThisThreadBlock,
|
||||
BlockwiseGemmPipe>;
|
||||
|
||||
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
|
||||
using Base::GetCShuffleLDSDescriptor;
|
||||
using Base::GetVgprToLDSEpilogueDescriptor;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::I2;
|
||||
using Base::I3;
|
||||
using Base::NumDTensor;
|
||||
|
||||
template <typename DoPads, index_t MPerTile, index_t NPerTile>
|
||||
__host__ __device__ static auto MakeMeanVarDescriptor_M_N(index_t M, index_t N)
|
||||
{
|
||||
const auto grid_desc_m_n =
|
||||
make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
|
||||
return tensor_operation::device::PadTensorDescriptor(
|
||||
grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
|
||||
}
|
||||
|
||||
template <typename DoPads, index_t MPerTile, index_t NPerTile>
|
||||
__host__ __device__ static auto MakeCountDescriptor_M_N(index_t M, index_t N)
|
||||
{
|
||||
// We will broadcast [N] to [M, N] in this descriptor
|
||||
// Hence, 1st stride is 0
|
||||
const auto grid_desc_m_n =
|
||||
make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, I1));
|
||||
return tensor_operation::device::PadTensorDescriptor(
|
||||
grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
|
||||
}
|
||||
|
||||
template <typename GridDescriptor_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
|
||||
{
|
||||
const auto M = grid_desc_m_n.GetLength(I0);
|
||||
const auto NBlock = grid_desc_m_n.GetLength(I1);
|
||||
const auto MBlock = M / MPerBlock;
|
||||
|
||||
const auto grid_desc_mblock_mperblock_nblock = transform_tensor_descriptor(
|
||||
grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
|
||||
make_pass_through_transform(NBlock)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}));
|
||||
|
||||
return grid_desc_mblock_mperblock_nblock;
|
||||
}
|
||||
|
||||
using GemmMeanVarGridDesc_M_N =
|
||||
decltype(MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(1, 1));
|
||||
|
||||
using GemmCountGridDesc_M_N =
|
||||
decltype(MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(1, 1));
|
||||
|
||||
__device__ EpilogueWelfordCShuffle(EDataType* p_welford_mean_grid_,
|
||||
EDataType* p_welford_var_grid_,
|
||||
int32_t* p_welford_count_grid_,
|
||||
index_t MRaw_,
|
||||
index_t NRaw_)
|
||||
: p_welford_mean_grid(p_welford_mean_grid_),
|
||||
p_welford_var_grid(p_welford_var_grid_),
|
||||
p_welford_count_grid(p_welford_count_grid_),
|
||||
NRaw(NRaw_)
|
||||
{
|
||||
index_t gemm_nblock = math::integer_divide_ceil(NRaw_, NPerBlock);
|
||||
|
||||
gemm_mean_var_grid_desc_m_nblock =
|
||||
MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(MRaw_, gemm_nblock);
|
||||
|
||||
gemm_count_grid_desc_m_nblock =
|
||||
MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(MRaw_, gemm_nblock);
|
||||
}
|
||||
|
||||
template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
typename CThreadBuf,
|
||||
typename DsGridPointer,
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
|
||||
__device__ void Run(CThreadBuf& c_thread_buf,
|
||||
DsGridPointer p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
CDEElementwiseOperation& cde_element_op,
|
||||
const index_t& block_m_id,
|
||||
const index_t& block_n_id)
|
||||
{
|
||||
// Vmem buffers
|
||||
const auto ds_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_ds_grid[i],
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
auto mean_var_grid_desc_mblock_mperblock_nblock =
|
||||
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
|
||||
gemm_mean_var_grid_desc_m_nblock);
|
||||
|
||||
auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_welford_mean_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
|
||||
|
||||
auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_welford_var_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
|
||||
|
||||
auto count_grid_desc_mblock_mperblock_nblock =
|
||||
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(gemm_count_grid_desc_m_nblock);
|
||||
auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_welford_count_grid, count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
|
||||
|
||||
// LDS buffer
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
|
||||
|
||||
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<CShuffleDataType*>(p_shared),
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
|
||||
.GetElementSpaceSize());
|
||||
|
||||
// tuple of reference to C/Ds tensor buffers (mix LDS and Vmem)
|
||||
const auto c_ds_buf_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_buf),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_buf[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// Thread transfer Vgpr to LDS
|
||||
auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor();
|
||||
|
||||
// Space Filling Curve Vgpr
|
||||
constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{};
|
||||
|
||||
// Space Filling Curve Vmem
|
||||
constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{};
|
||||
|
||||
// C thread descriptor
|
||||
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
|
||||
BlockwiseGemmPipe::
|
||||
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
|
||||
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_desc_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// Thread transfer LDS to Vmem
|
||||
auto cde_shuffle_block_copy_lds_and_global =
|
||||
Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, AccDataType>(
|
||||
c_ds_desc_refs,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id);
|
||||
|
||||
// Block descriptor
|
||||
constexpr auto
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
|
||||
GetCShuffleLDSDescriptor();
|
||||
|
||||
// E Vgpr buffer
|
||||
constexpr index_t PostShuffleThreadSliceSize_M =
|
||||
(CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) /
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1);
|
||||
|
||||
constexpr index_t PostShuffleThreadSliceSize_N =
|
||||
(CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) /
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3);
|
||||
|
||||
constexpr auto PostShuffleThreadSliceSize_M_N =
|
||||
Sequence<PostShuffleThreadSliceSize_M, PostShuffleThreadSliceSize_N>{};
|
||||
|
||||
// Welford
|
||||
constexpr auto post_shuffle_thread_desc_m_n =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{},
|
||||
Number<PostShuffleThreadSliceSize_M>{},
|
||||
Number<1>{},
|
||||
Number<PostShuffleThreadSliceSize_N>{}));
|
||||
|
||||
auto e_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
post_shuffle_thread_desc_m_n.GetElementSpaceSize());
|
||||
|
||||
using PostShuffleThreadClusterSize_M_N = Sequence<
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1),
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3)>;
|
||||
|
||||
constexpr auto post_shuffle_thread_cluster_desc =
|
||||
make_cluster_descriptor(PostShuffleThreadClusterSize_M_N{}, Sequence<0, 1>{});
|
||||
|
||||
const auto post_shuffle_thread_cluster_idx =
|
||||
post_shuffle_thread_cluster_desc.CalculateBottomIndex(
|
||||
make_multi_index(get_thread_local_1d_id()));
|
||||
|
||||
const auto post_shuffle_thread_data_idx_begin =
|
||||
post_shuffle_thread_cluster_idx * PostShuffleThreadSliceSize_M_N;
|
||||
|
||||
constexpr auto thread_welford_src_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<PostShuffleThreadSliceSize_M>{}, Number<PostShuffleThreadSliceSize_N>{}));
|
||||
|
||||
constexpr auto thread_welford_dst_desc_m =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<PostShuffleThreadSliceSize_M>{}));
|
||||
|
||||
using ThreadwiseWelford = ThreadwiseWelford<AccDataType,
|
||||
decltype(thread_welford_src_desc_m_k),
|
||||
decltype(thread_welford_dst_desc_m)>;
|
||||
|
||||
using BlockwiseWelford = BlockwiseWelford<AccDataType,
|
||||
BlockSize,
|
||||
PostShuffleThreadClusterSize_M_N,
|
||||
Sequence<0, 1>,
|
||||
false>;
|
||||
|
||||
constexpr int num_shuffleM =
|
||||
MPerBlock / (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma);
|
||||
|
||||
constexpr int num_shuffleN =
|
||||
NPerBlock / (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma);
|
||||
|
||||
using mean_var_vgpr_type = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
thread_welford_dst_desc_m.GetElementSpaceSize()));
|
||||
|
||||
using welford_count_vgpr_type =
|
||||
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
|
||||
thread_welford_dst_desc_m.GetElementSpaceSize()));
|
||||
|
||||
Array<ThreadwiseWelford, num_shuffleM> threadwise_welfords;
|
||||
Array<mean_var_vgpr_type, num_shuffleM> mean_thread_bufs;
|
||||
Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs;
|
||||
Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;
|
||||
|
||||
int max_count = PostShuffleThreadSliceSize_N * num_shuffleN;
|
||||
const auto nblock = mean_var_grid_desc_mblock_mperblock_nblock.GetLength(I2);
|
||||
|
||||
// tail block
|
||||
if(block_n_id % nblock == nblock - 1)
|
||||
{
|
||||
constexpr index_t NPerShuffleBlock =
|
||||
CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma;
|
||||
|
||||
int NPerBlockTail = NRaw - NPerBlock * (nblock - 1);
|
||||
int thread_max_len =
|
||||
PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[I1] + 1);
|
||||
int shuffle_step = 0;
|
||||
while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN)
|
||||
{
|
||||
++shuffle_step;
|
||||
thread_max_len += NPerShuffleBlock;
|
||||
}
|
||||
|
||||
int delta = 0;
|
||||
if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N)
|
||||
delta = 0;
|
||||
else if(NPerBlockTail > thread_max_len)
|
||||
delta = PostShuffleThreadSliceSize_N;
|
||||
else
|
||||
delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail;
|
||||
|
||||
max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta;
|
||||
}
|
||||
|
||||
// Initialize Welford
|
||||
static_for<0, num_shuffleM, 1>{}([&](auto i) {
|
||||
threadwise_welfords(i).max_count_ = max_count;
|
||||
mean_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
thread_welford_dst_desc_m.GetElementSpaceSize());
|
||||
|
||||
var_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
thread_welford_dst_desc_m.GetElementSpaceSize());
|
||||
|
||||
welford_count_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
|
||||
thread_welford_dst_desc_m.GetElementSpaceSize());
|
||||
|
||||
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
|
||||
mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
|
||||
var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
|
||||
welford_count_thread_bufs(i)(j) = 0;
|
||||
});
|
||||
});
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
|
||||
|
||||
// Run CShuffle + Store E + Welford threadwise
|
||||
int shuffleM_index = __builtin_amdgcn_readfirstlane(0);
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to read from LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each thread shuffle data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(
|
||||
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c_thread_buf,
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
|
||||
c_shuffle_block_buf);
|
||||
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
// Read LDS / Vmem + CDE elementwise operation
|
||||
cde_shuffle_block_copy_lds_and_global.RunRead(c_ds_desc_refs, c_ds_buf_refs);
|
||||
|
||||
// Store to Vmem, but keep data in Vgpr for Welford
|
||||
cde_shuffle_block_copy_lds_and_global.RunWriteAndStoreVgpr(
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(e_grid_buf),
|
||||
tie(post_shuffle_thread_desc_m_n),
|
||||
tie(e_thread_buf));
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
|
||||
// move on Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
|
||||
c_ds_desc_refs, i + I1, cde_global_step);
|
||||
});
|
||||
|
||||
// move on E
|
||||
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
|
||||
}
|
||||
|
||||
// Threadwise welford
|
||||
auto& threadwise_welford = threadwise_welfords(shuffleM_index);
|
||||
auto& mean_thread_buf = mean_thread_bufs(shuffleM_index);
|
||||
auto& var_thread_buf = var_thread_bufs(shuffleM_index);
|
||||
|
||||
threadwise_welford.Run(e_thread_buf, mean_thread_buf, var_thread_buf);
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto de_global_step = sfc_cde_global.GetForwardStep(access_id);
|
||||
constexpr int shuffleMInc =
|
||||
de_global_step[I1] /
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
|
||||
I1);
|
||||
shuffleM_index = __builtin_amdgcn_readfirstlane(shuffleM_index + shuffleMInc);
|
||||
}
|
||||
});
|
||||
|
||||
// Blockwise welford and write out
|
||||
static_for<0, num_shuffleM, 1>{}([&](auto i) {
|
||||
auto& mean_thread_buf = mean_thread_bufs(i);
|
||||
auto& var_thread_buf = var_thread_bufs(i);
|
||||
auto& count_thread_buf = welford_count_thread_bufs(i);
|
||||
|
||||
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
|
||||
block_sync_lds();
|
||||
count_thread_buf(j) = threadwise_welfords(i).cur_count_;
|
||||
BlockwiseWelford::Run(mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
|
||||
});
|
||||
|
||||
if(post_shuffle_thread_cluster_idx[I1] == 0)
|
||||
{
|
||||
constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, Number<PostShuffleThreadSliceSize_M>{}, I1));
|
||||
|
||||
constexpr int shuffleMPerBlock =
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
|
||||
I1);
|
||||
|
||||
auto mean_var_count_thread_copy_index = make_multi_index(
|
||||
block_m_id, // mblock
|
||||
shuffleMPerBlock * i + post_shuffle_thread_data_idx_begin[I0], // mperblock
|
||||
block_n_id); // nblock
|
||||
|
||||
auto mean_var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
EDataType,
|
||||
decltype(thread_welford_desc_I_m_I),
|
||||
decltype(mean_var_grid_desc_mblock_mperblock_nblock),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
|
||||
Sequence<0, 1, 2>,
|
||||
1,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{mean_var_grid_desc_mblock_mperblock_nblock,
|
||||
mean_var_count_thread_copy_index,
|
||||
tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
mean_var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
|
||||
make_tuple(I0, I0, I0),
|
||||
mean_thread_buf,
|
||||
mean_var_grid_desc_mblock_mperblock_nblock,
|
||||
mean_grid_buf); // write mean
|
||||
|
||||
mean_var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
|
||||
make_tuple(I0, I0, I0),
|
||||
var_thread_buf,
|
||||
mean_var_grid_desc_mblock_mperblock_nblock,
|
||||
var_grid_buf); // write variance
|
||||
|
||||
// Stride of count is [0, 1]. Only the first row in count[0, 0:nblock] need
|
||||
// to be written.
|
||||
if(i == 0 && block_m_id == 0 && post_shuffle_thread_cluster_idx[I0] == 0)
|
||||
{
|
||||
auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
int32_t,
|
||||
int32_t,
|
||||
decltype(thread_welford_desc_I_m_I),
|
||||
decltype(count_grid_desc_mblock_mperblock_nblock),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
|
||||
Sequence<0, 1, 2>,
|
||||
1,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
false>{count_grid_desc_mblock_mperblock_nblock,
|
||||
mean_var_count_thread_copy_index,
|
||||
tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
count_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
|
||||
make_tuple(I0, I0, I0),
|
||||
count_thread_buf,
|
||||
count_grid_desc_mblock_mperblock_nblock,
|
||||
welford_count_grid_buf); // write count
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
EDataType* p_welford_mean_grid;
|
||||
EDataType* p_welford_var_grid;
|
||||
int32_t* p_welford_count_grid;
|
||||
index_t NRaw;
|
||||
GemmMeanVarGridDesc_M_N gemm_mean_var_grid_desc_m_nblock;
|
||||
GemmCountGridDesc_M_N gemm_count_grid_desc_m_nblock;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,195 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
typename CDEElementwiseOperation,
|
||||
typename ThisThreadBlock,
|
||||
typename BlockwiseGemmPipe>
|
||||
struct EpilogueCShuffle
|
||||
: EpilogueCShuffleBase<DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
CDEElementwiseOperation,
|
||||
ThisThreadBlock,
|
||||
BlockwiseGemmPipe>
|
||||
{
|
||||
using Base = EpilogueCShuffleBase<
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
CDEElementwiseOperation,
|
||||
ThisThreadBlock,
|
||||
BlockwiseGemmPipe>;
|
||||
|
||||
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
|
||||
using Base::GetCShuffleLDSDescriptor;
|
||||
using Base::GetVgprToLDSEpilogueDescriptor;
|
||||
using Base::I1;
|
||||
using Base::NumDTensor;
|
||||
|
||||
template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
typename CThreadBuf,
|
||||
typename DsGridPointer,
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
|
||||
__device__ static void Run(CThreadBuf& c_thread_buf,
|
||||
DsGridPointer p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
CDEElementwiseOperation& cde_element_op,
|
||||
const index_t& block_m_id,
|
||||
const index_t& block_n_id)
|
||||
{
|
||||
const auto ds_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_ds_grid[i],
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
// C mapping in single thread.
|
||||
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
|
||||
BlockwiseGemmPipe::
|
||||
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
|
||||
|
||||
// LDS buffer
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
|
||||
|
||||
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<CShuffleDataType*>(p_shared),
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
|
||||
.GetElementSpaceSize());
|
||||
|
||||
// Thread transfer Vgpr to LDS
|
||||
auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor();
|
||||
|
||||
// Space Filling Curve Vgpr
|
||||
constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{};
|
||||
|
||||
// Space Filling Curve Vmem
|
||||
constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{};
|
||||
|
||||
// Block descriptor
|
||||
constexpr auto
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
|
||||
GetCShuffleLDSDescriptor();
|
||||
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_desc_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// Thread transfer LDS to Vmem
|
||||
auto cde_shuffle_block_copy_lds_and_global =
|
||||
Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, EDataType>(
|
||||
c_ds_desc_refs,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id);
|
||||
|
||||
// tuple of reference to C/Ds tensor buffers
|
||||
const auto c_ds_buf_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_buf),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_buf[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
|
||||
|
||||
// CShuffle and Store
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(
|
||||
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c_thread_buf,
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
|
||||
c_shuffle_block_buf);
|
||||
|
||||
// make sure it's safe to read from LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each block loads its C data from LDS, D from global, applies elementwise
|
||||
// operation and stores result E to global
|
||||
cde_shuffle_block_copy_lds_and_global.Run(
|
||||
c_ds_desc_refs,
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(e_grid_buf));
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
|
||||
// move on Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
|
||||
c_ds_desc_refs, i + I1, cde_global_step);
|
||||
});
|
||||
|
||||
// move on E
|
||||
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,253 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
typename CDEElementwiseOperation,
|
||||
typename ThisThreadBlock,
|
||||
typename BlockwiseGemmPipe>
|
||||
struct EpilogueCShuffleBase
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static constexpr auto EShuffleBlockTransferScalarPerVector =
|
||||
CDEShuffleBlockTransferScalarPerVectors{}[I0];
|
||||
|
||||
using SpaceFillingCurveVgpr =
|
||||
SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, BlockwiseGemmPipe::MAccVgprs>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
Sequence<CShuffleMRepeatPerShuffle,
|
||||
1,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
1,
|
||||
1,
|
||||
BlockwiseGemmPipe::MAccVgprs>>;
|
||||
|
||||
using SpaceFillingCurveVmem = SpaceFillingCurve<
|
||||
Sequence<1, MPerBlock, 1, NPerBlock>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma>>;
|
||||
|
||||
// *Caution Here repeat is shuffle repeat
|
||||
__device__ static constexpr auto
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
|
||||
{
|
||||
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1,
|
||||
Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
|
||||
I1,
|
||||
Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));
|
||||
|
||||
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetCShuffleLDSDescriptor()
|
||||
{
|
||||
// C mapping in single block
|
||||
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
|
||||
BlockwiseGemmPipe::
|
||||
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
|
||||
|
||||
constexpr auto MWave =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I1);
|
||||
constexpr auto MSubGroup =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I2);
|
||||
constexpr auto NWave =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I4);
|
||||
constexpr auto NThreadPerSubGroup =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I5);
|
||||
constexpr auto MAccVgprs =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I6);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(),
|
||||
make_tuple(make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
|
||||
MWave, // MWave
|
||||
MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
|
||||
MAccVgprs)),
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
|
||||
NWave, // NWave
|
||||
NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{}));
|
||||
}
|
||||
|
||||
__device__ static auto GetVgprToLDSEpilogueDescriptor()
|
||||
{
|
||||
// C mapping in single block
|
||||
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
|
||||
BlockwiseGemmPipe::
|
||||
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
|
||||
|
||||
constexpr auto MWave =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I1);
|
||||
constexpr auto MSubGroup =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I2);
|
||||
constexpr auto NWave =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I4);
|
||||
constexpr auto NThreadPerSubGroup =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I5);
|
||||
constexpr auto MAccVgprs =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I6);
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
BlockwiseGemmPipe::CalculateCThreadOriginDataIndex(I0, I0);
|
||||
|
||||
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
|
||||
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
|
||||
|
||||
const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor
|
||||
.CalculateBottomIndex(make_multi_index(m_thread_data_on_block));
|
||||
|
||||
const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_block_idx =
|
||||
n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
return ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(BlockwiseGemmPipe::
|
||||
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()),
|
||||
decltype(GetCShuffleLDSDescriptor()),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMRepeatPerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
MAccVgprs>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{GetCShuffleLDSDescriptor(),
|
||||
make_multi_index(0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
0,
|
||||
n_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
}
|
||||
|
||||
template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
typename InterDataType,
|
||||
typename CDsDescRefs,
|
||||
typename EGridDesc>
|
||||
__device__ static auto
|
||||
GetLDSToVmemEpilogueDescriptor(CDsDescRefs& c_ds_desc_refs,
|
||||
EGridDesc& e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
CDEElementwiseOperation& cde_element_op,
|
||||
const index_t& block_m_id,
|
||||
const index_t& block_n_id)
|
||||
{
|
||||
// tuple of starting index of C/Ds blockwise copy
|
||||
const auto idx_c_ds_block_begin = container_concat(
|
||||
make_tuple(make_multi_index(0, 0, 0, 0)),
|
||||
generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// blockwise copy which loads C from LDS, D from global, applies elementwise
|
||||
// operation and stores result E to global
|
||||
return ThreadGroupTensorSliceTransfer_v7r3<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
Tuple<EDataType>,
|
||||
CDsDescRefs,
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CDEElementwiseOperation, // ElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves *
|
||||
NPerWmma>, // BlockSliceLengths,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder,
|
||||
Sequence<0, 1, 2, 3>, // SrcDimAccessOrder,
|
||||
Sequence<0, 1, 2, 3>, // DstDimAccessOrder,
|
||||
3, // SrcVectorDim,
|
||||
3, // DstVectorDim,
|
||||
CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
|
||||
EShuffleBlockTransferScalarPerVector, // DstScalarPerVector
|
||||
sequence_merge_t<
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
1,
|
||||
Tuple<InterDataType>>{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
|
||||
cde_element_op};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -315,8 +315,6 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
using Base::MakeDsGridDescriptor_M_N;
|
||||
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
|
||||
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
|
||||
|
||||
using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
@@ -556,7 +554,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(AsGridPointer& p_as_grid,
|
||||
BsGridPointer& p_bs_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
@@ -565,7 +564,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
|
||||
@@ -610,6 +610,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid,
|
||||
@@ -627,16 +628,20 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
b_scale_struct);
|
||||
b_scale_struct,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
// Wrapper function to have __global__ function in common
|
||||
// between gemm_universal, b_scale, ab_scale, etc.
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
__device__ static void
|
||||
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(void* p_shared,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
// shift A matrices pointer for splitk
|
||||
AsGridPointer p_as_grid_splitk;
|
||||
@@ -663,7 +668,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op);
|
||||
karg.cde_element_op,
|
||||
epilogue_args);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -209,8 +209,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
using Base::MakeDsGridDescriptor_M_N;
|
||||
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
|
||||
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
|
||||
|
||||
using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
@@ -533,7 +531,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(AsGridPointer& p_as_grid,
|
||||
BsGridPointer& p_bs_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
@@ -543,7 +542,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
|
||||
@@ -593,6 +593,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid,
|
||||
@@ -610,16 +611,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
b_scale_struct);
|
||||
b_scale_struct,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
// NOTE: Wrapper function to have __global__ function in common
|
||||
// between gemm_universal, b_scale, ab_scale, etc.
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
__device__ static void
|
||||
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(void* p_shared,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
// shift A matrices pointer for splitk
|
||||
AsGridPointer p_as_grid_splitk;
|
||||
@@ -647,7 +652,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op);
|
||||
karg.cde_element_op,
|
||||
epilogue_args);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -23,6 +23,8 @@
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -46,12 +48,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared, splitk_batch_offset, karg);
|
||||
p_shared, splitk_batch_offset, karg, epilogue_args);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
@@ -262,9 +268,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
static_assert(!PermuteA, "PermuteA is not supported");
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
// if arch = gfx942
|
||||
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
|
||||
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
@@ -539,23 +543,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
// *Caution Here repeat is shuffle repeat
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
|
||||
{
|
||||
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1,
|
||||
Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
|
||||
I1,
|
||||
Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));
|
||||
|
||||
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
|
||||
}
|
||||
|
||||
using BlockwiseGemmPipe =
|
||||
remove_cvref_t<decltype(BlockGemmPipeline_Selector<BlkGemmPipelineVer,
|
||||
BlkGemmPipeSched,
|
||||
@@ -578,6 +565,46 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
NRepeat,
|
||||
KPack>())>;
|
||||
|
||||
// Used to create obj in global function and pass it to Run method
|
||||
using EpilogueCShuffle =
|
||||
EpilogueCShuffle<DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
CDEElementwiseOperation,
|
||||
ThisThreadBlock,
|
||||
BlockwiseGemmPipe>;
|
||||
|
||||
using EpilogueWelfordCShuffle = EpilogueWelfordCShuffle<
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
CDEElementwiseOperation,
|
||||
ThisThreadBlock,
|
||||
BlockwiseGemmPipe,
|
||||
BlockSize>;
|
||||
|
||||
template <typename DEGridDesc>
|
||||
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
@@ -821,6 +848,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
|
||||
}
|
||||
|
||||
template <typename EpilogueType>
|
||||
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
@@ -838,7 +866,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
|
||||
// LDS allocation for C shuffle in LDS
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
|
||||
EpilogueType::
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
|
||||
|
||||
constexpr auto c_block_size =
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
|
||||
@@ -867,6 +896,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename BScaleStruct,
|
||||
typename EpilogueArgument,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
@@ -887,7 +917,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
const index_t& block_m_id,
|
||||
const index_t& block_n_id,
|
||||
const index_t& num_k_block_per_scale,
|
||||
BScaleStruct& b_scale_struct)
|
||||
BScaleStruct& b_scale_struct,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
const auto as_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -903,16 +934,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
|
||||
const auto ds_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_ds_grid[i],
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
|
||||
|
||||
@@ -984,240 +1005,16 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
num_k_block_per_scale);
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
// C mapping in single thread.
|
||||
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
|
||||
blockwise_gemm_pipeline
|
||||
.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
|
||||
|
||||
// C mapping in single block
|
||||
constexpr auto
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
|
||||
blockwise_gemm_pipeline
|
||||
.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
|
||||
|
||||
constexpr auto MWave =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I1);
|
||||
constexpr auto MSubGroup =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I2);
|
||||
constexpr auto NWave =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I4);
|
||||
constexpr auto NThreadPerSubGroup =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I5);
|
||||
constexpr auto MAccVgprs =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I6);
|
||||
|
||||
// LDS descriptor, shuffle and write out in MRepeat x NRepeat times
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
|
||||
|
||||
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<CShuffleDataType*>(p_shared),
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
|
||||
.GetElementSpaceSize());
|
||||
|
||||
constexpr auto
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
|
||||
transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
make_tuple(
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
|
||||
MWave, // MWave
|
||||
MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
|
||||
MAccVgprs)),
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
|
||||
NWave, // NWave
|
||||
NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<>{},
|
||||
Sequence<0, 1, 2, 6>{},
|
||||
Sequence<>{},
|
||||
Sequence<3, 4, 5>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0);
|
||||
|
||||
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
|
||||
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
|
||||
|
||||
const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(
|
||||
MRepeat, MWave, MSubGroup, MAccVgprs))),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor
|
||||
.CalculateBottomIndex(make_multi_index(m_thread_data_on_block));
|
||||
|
||||
const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(
|
||||
NRepeat, NWave, NThreadPerSubGroup))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_block_idx =
|
||||
n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor
|
||||
.CalculateBottomIndex(make_multi_index(n_thread_data_on_block));
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
|
||||
decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMRepeatPerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
MAccVgprs>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
1, // vector write pixel
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
|
||||
make_multi_index(0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
0,
|
||||
n_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_desc_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of reference to C/Ds tensor buffers
|
||||
const auto c_ds_buf_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_buf),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_buf[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of starting index of C/Ds blockwise copy
|
||||
const auto idx_c_ds_block_begin = container_concat(
|
||||
make_tuple(make_multi_index(0, 0, 0, 0)),
|
||||
generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// blockwise copy which loads C from LDS, D from global, applies elementwise
|
||||
// operation and stores result E to global
|
||||
auto cde_shuffle_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
Tuple<EDataType>,
|
||||
decltype(c_ds_desc_refs),
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CDEElementwiseOperation, // ElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder,
|
||||
Sequence<0, 1, 2, 3>, // SrcDimAccessOrder,
|
||||
Sequence<0, 1, 2, 3>, // DstDimAccessOrder,
|
||||
3, // SrcVectorDim,
|
||||
3, // DstVectorDim,
|
||||
CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
|
||||
EShuffleBlockTransferScalarPerVector, // DstScalarPerVector
|
||||
sequence_merge_t<
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
|
||||
cde_element_op};
|
||||
|
||||
// space filling curve for local reg & global memory
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
Sequence<CShuffleMRepeatPerShuffle,
|
||||
1,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
1,
|
||||
1,
|
||||
MAccVgprs>>{};
|
||||
|
||||
// space filling curve for shuffled blockwise C in global mem
|
||||
constexpr auto sfc_cde_global =
|
||||
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(
|
||||
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c_thread_buf,
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
|
||||
c_shuffle_block_buf);
|
||||
|
||||
// make sure it's safe to read from LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each block loads its C data from LDS, D from global, applies elementwise
|
||||
// operation and stores result E to global
|
||||
cde_shuffle_block_copy_lds_and_global.Run(
|
||||
c_ds_desc_refs,
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(e_grid_buf));
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
|
||||
// move on Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
|
||||
c_ds_desc_refs, i + I1, cde_global_step);
|
||||
});
|
||||
|
||||
// move on E
|
||||
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
epilogue_args.template Run<EGlobalMemoryDataOperation>(
|
||||
c_thread_buf,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -43,7 +43,8 @@ template <typename SrcDatas,
|
||||
index_t DstScalarPerVector,
|
||||
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
|
||||
typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
|
||||
index_t NumThreadScratch = 1>
|
||||
index_t NumThreadScratch = 1,
|
||||
typename InterDatas = DstDatas>
|
||||
struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -153,7 +154,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
// loop over space-filling curve
|
||||
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
|
||||
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
|
||||
auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
|
||||
auto elm_vectors = generate_vectors<InterDatas, SrcScalarPerVector>();
|
||||
|
||||
bool oob_val = true;
|
||||
|
||||
@@ -226,9 +227,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
auto dst_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iDst) -> auto& {
|
||||
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
|
||||
using InterData = remove_cvref_t<tuple_element_t<iDst.value, InterDatas>>;
|
||||
|
||||
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
|
||||
using elem_op_vec_t =
|
||||
typename vector_type<InterData, elem_op_vec_len>::type;
|
||||
|
||||
return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
|
||||
},
|
||||
@@ -297,17 +299,17 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
__device__ void
|
||||
TransposeFromElmToDst(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>;
|
||||
using InterData = remove_cvref_t<decltype(InterDatas{}[I0])>;
|
||||
|
||||
using ElmThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
InterData,
|
||||
SrcScalarPerVector,
|
||||
decltype(GetSrcThreadScratchDescriptor()),
|
||||
true>;
|
||||
using DstThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
InterData,
|
||||
DstScalarPerVector,
|
||||
decltype(GetDstThreadScratchDescriptor()),
|
||||
true>;
|
||||
@@ -319,11 +321,11 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_[thread_scratch_id]);
|
||||
|
||||
if constexpr(SrcVectorDim != DstVectorDim &&
|
||||
((is_same<half_t, remove_cvref_t<DstData>>::value &&
|
||||
((is_same<half_t, remove_cvref_t<InterData>>::value &&
|
||||
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
|
||||
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
|
||||
(is_same<f8_t, remove_cvref_t<InterData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
|
||||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
|
||||
(is_same<int8_t, remove_cvref_t<InterData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
|
||||
{
|
||||
// each transpose does
|
||||
@@ -356,8 +358,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
constexpr auto data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
|
||||
|
||||
using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
|
||||
using src_vector_t = vector_type_maker_t<InterData, SrcScalarPerVector>;
|
||||
using dst_vector_t = vector_type_maker_t<InterData, DstScalarPerVector>;
|
||||
|
||||
// get DstScalarPerVector # of read-only references to src vectors from
|
||||
// src_thread_scratch_
|
||||
@@ -380,7 +382,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
Number<num_dst_vector>{});
|
||||
|
||||
// do data transpose
|
||||
transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
|
||||
transpose_vectors<InterData, DstScalarPerVector, SrcScalarPerVector>{}(
|
||||
src_vector_refs, dst_vector_refs);
|
||||
});
|
||||
}
|
||||
@@ -393,6 +395,104 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
dst_vectors_tuple_(thread_scratch_id) = bit_cast<DstVectorTuple>(dst_thread_scratch_.data_);
|
||||
}
|
||||
|
||||
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
|
||||
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
|
||||
// DstVgprDescs: Tuple<const DstVgprDesc0&, const DstVgprDesc1&, ...>
|
||||
// DstVgprBuffers: Tuple<DstVgprBuffer0&, DstVgprBuffer1&, ...>
|
||||
template <typename DstBuffers,
|
||||
typename DstVgprDescs,
|
||||
typename DstVgprBuffers,
|
||||
index_t ThreadScratchId = 0,
|
||||
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
|
||||
__device__ void
|
||||
RunWriteAndStoreVgpr(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
const DstVgprDescs&,
|
||||
DstVgprBuffers dst_vgpr_buf,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// Same functionality of RunWrite but additionally store internal Vgpr in dst_vgpr_buf
|
||||
OOBCheck(thread_scratch_id);
|
||||
TransposeFromElmToDst(thread_scratch_id);
|
||||
|
||||
// Vgpr buffer origin is set internally to 0
|
||||
constexpr auto dst_slice_origin_idx =
|
||||
generate_tuple([&](auto) { return I0; }, Number<nDim>{});
|
||||
constexpr auto dst_scalar_step_in_vector =
|
||||
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
|
||||
// loop over space-filling curve
|
||||
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
|
||||
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
|
||||
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
// copy data from buf_vectors into dst_bufs
|
||||
using DstData = remove_cvref_t<decltype(DstDatas{}[i])>;
|
||||
using InterData = remove_cvref_t<decltype(InterDatas{}[i])>;
|
||||
|
||||
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
|
||||
using dst_vector_t =
|
||||
typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
|
||||
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto j) {
|
||||
dst_vector.template AsType<DstData>()(j) =
|
||||
type_convert<DstData>(dst_vectors[i].template AsType<InterData>()[j]);
|
||||
});
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
|
||||
dst_coords_[i]);
|
||||
|
||||
constexpr InMemoryDataOperationEnum DstInMemOp =
|
||||
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
|
||||
|
||||
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_coords_[i].GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
// store Vgpr
|
||||
using DstVgprDesc = remove_cvref_t<decltype(DstVgprDescs{}.At(i))>;
|
||||
static_assert(DstVgprDesc::IsKnownAtCompileTime(),
|
||||
"wrong! DstDesc need to known at compile-time");
|
||||
constexpr auto dst_vgpr_desc = DstVgprDesc{};
|
||||
|
||||
constexpr auto src_data_idx = DstSpaceFillingCurve::GetIndex(iAccess);
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr index_t dst_offset =
|
||||
dst_vgpr_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) +
|
||||
src_data_idx + j * dst_scalar_step_in_vector);
|
||||
|
||||
dst_vgpr_buf(I0)(Number<dst_offset>{}) =
|
||||
is_dst_valid ? dst_vectors[i].template AsType<InterData>()[j]
|
||||
: NumericLimits<InterData>::QuietNaN();
|
||||
});
|
||||
});
|
||||
|
||||
// move coordinate
|
||||
if constexpr(iAccess.value != dst_num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess);
|
||||
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
move_tensor_coordinate(dst_descs[i],
|
||||
dst_coords_(i),
|
||||
make_tensor_coordinate_step(dst_descs[i], forward_step));
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
if constexpr(DstResetCoordinateAfterRunFlags::At(i))
|
||||
{
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
|
||||
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
|
||||
template <typename DstBuffers,
|
||||
@@ -402,6 +502,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
DstBuffers dst_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
static_assert(is_same_v<InterDatas, DstDatas>,
|
||||
"RunWrite doesn't support inter data type different from dst data type");
|
||||
|
||||
OOBCheck(thread_scratch_id);
|
||||
TransposeFromElmToDst(thread_scratch_id);
|
||||
|
||||
@@ -630,8 +733,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
|
||||
private:
|
||||
using SrcVectorsType = decltype(generate_vectors<SrcDatas, SrcScalarPerVector>());
|
||||
using ElmVectorsType = decltype(generate_vectors<DstDatas, SrcScalarPerVector>());
|
||||
using DstVectorsType = decltype(generate_vectors<DstDatas, DstScalarPerVector>());
|
||||
using ElmVectorsType = decltype(generate_vectors<InterDatas, SrcScalarPerVector>());
|
||||
using DstVectorsType = decltype(generate_vectors<InterDatas, DstScalarPerVector>());
|
||||
|
||||
static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
|
||||
static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
@@ -319,22 +319,67 @@ struct gfx9_t
|
||||
struct gfx950_t
|
||||
{
|
||||
};
|
||||
struct gfx103_t
|
||||
{
|
||||
};
|
||||
struct gfx11_t
|
||||
{
|
||||
};
|
||||
struct gfx12_t
|
||||
{
|
||||
};
|
||||
struct gfx_invalid_t
|
||||
{
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_device_arch()
|
||||
{
|
||||
// FIXME(0): on all devices except gfx11 it returns gfx12_t
|
||||
// FIXME(1): during the host compilation pass it returns gfx12_t
|
||||
#if defined(__gfx11__)
|
||||
return gfx11_t{};
|
||||
#else // if defined(__gfx12__)
|
||||
#else
|
||||
return gfx12_t{};
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_n_words_per_128b() { return 4; }
|
||||
|
||||
namespace detail {
|
||||
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx9_t) { return 32; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx103_t) { return 32; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx11_t) { return 32; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx12_t) { return 32; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx950_t) { return 64; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx_invalid_t) { return 0; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto arch_tag_dispatch()
|
||||
{
|
||||
#if defined(__gfx103__)
|
||||
return gfx103_t{};
|
||||
#elif defined(__gfx11__)
|
||||
return gfx11_t{};
|
||||
#elif defined(__gfx12__)
|
||||
return gfx12_t{};
|
||||
#elif defined(__gfx950__)
|
||||
return gfx950_t{};
|
||||
#elif defined(__gfx9__)
|
||||
return gfx9_t{};
|
||||
#else
|
||||
return gfx_invalid_t{};
|
||||
#endif
|
||||
}
|
||||
} // namespace detail
|
||||
CK_TILE_DEVICE static constexpr auto get_n_lds_banks()
|
||||
{
|
||||
return detail::get_n_lds_banks(detail::arch_tag_dispatch());
|
||||
}
|
||||
|
||||
enum LLVMSchedGroupMask : int32_t
|
||||
{
|
||||
NONE = 0,
|
||||
|
||||
0
include/ck_tile/host/tensor_shuffle_utils.hpp
Executable file → Normal file
0
include/ck_tile/host/tensor_shuffle_utils.hpp
Executable file → Normal file
@@ -442,7 +442,7 @@ struct BlockFmhaV3PipelineDefaultPolicy
|
||||
|
||||
constexpr index_t SingleVSize = [&]() {
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t Banks = 32; // TODO: need change based on arch
|
||||
constexpr index_t Banks = get_n_lds_banks();
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
static_assert(PixelsPerRow % kKPack == 0);
|
||||
|
||||
@@ -140,7 +140,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
|
||||
constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers<Problem>();
|
||||
|
||||
constexpr index_t Banks = 32; // TODO: need change based on arch
|
||||
constexpr index_t Banks = get_n_lds_banks();
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(PixelsPerRow % kKPack == 0);
|
||||
|
||||
@@ -465,7 +465,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
|
||||
constexpr index_t SingleVSize = [&]() {
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t Banks = 32; // TODO: need change based on arch
|
||||
constexpr index_t Banks = get_n_lds_banks();
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
static_assert(PixelsPerRow % kKPack == 0);
|
||||
@@ -620,7 +620,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
|
||||
{
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t Banks = 32; // TODO: need change based on arch
|
||||
constexpr index_t Banks = get_n_lds_banks();
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(PixelsPerRow % kKPack == 0);
|
||||
|
||||
@@ -71,7 +71,7 @@ struct UniversalGemmBasePolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
@@ -94,7 +94,7 @@ struct UniversalGemmBasePolicy
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
constexpr auto MLdsLayer =
|
||||
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
|
||||
max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
|
||||
@@ -141,7 +141,7 @@ struct UniversalGemmBasePolicy
|
||||
* @return B tensor LDS block descriptor.
|
||||
*/
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
@@ -166,7 +166,7 @@ struct UniversalGemmBasePolicy
|
||||
constexpr auto BK0 = number<KPerBlock / KPack>{};
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
|
||||
max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
@@ -658,25 +658,27 @@ struct UniversalGemmBasePolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
constexpr auto a_lds_desc = MakeALdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_a = integer_least_multiple(
|
||||
sizeof(typename Problem::ADataType) * a_lds_desc.get_element_space_size(), 16);
|
||||
constexpr index_t smem_size_a =
|
||||
integer_least_multiple(sizeof(typename Problem::ADataType) *
|
||||
Problem::BlockGemmShape::kM * Problem::BlockGemmShape::kK,
|
||||
16);
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
{
|
||||
constexpr auto b_lds_desc = MakeBLdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_b = integer_least_multiple(
|
||||
sizeof(typename Problem::BDataType) * b_lds_desc.get_element_space_size(), 16);
|
||||
constexpr index_t smem_size_b =
|
||||
integer_least_multiple(sizeof(typename Problem::BDataType) *
|
||||
Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK,
|
||||
16);
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
|
||||
|
||||
@@ -20,7 +20,7 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
@@ -483,6 +483,7 @@ struct QuantGemmKernel
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -790,6 +791,7 @@ struct QuantGemmKernel
|
||||
}();
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
|
||||
return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view);
|
||||
}
|
||||
else
|
||||
@@ -802,6 +804,7 @@ struct QuantGemmKernel
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& aq_pad_view = views.at(I1);
|
||||
const auto& b_pad_view = views.at(I2);
|
||||
@@ -867,6 +870,7 @@ struct QuantGemmKernel
|
||||
const auto& b_block_window = [&]() {
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
|
||||
return make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<GemmPipeline::flatNPerWarp>{},
|
||||
|
||||
@@ -317,13 +317,88 @@ struct QuantGroupedGemmKernel
|
||||
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
static_assert(GemmPipeline::DoubleSmemBuffer == false,
|
||||
"DoubleSmemBuffer needs to be false");
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
RunGemmWithPipelineSelection(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
// Only for BQuantGrouped DoubleSmemBuffer is supported
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true &&
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
RunGemmWithPipelineSelection2LDS(a_ptr,
|
||||
b_ptr,
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
RunGemmWithPipelineSelection(a_ptr,
|
||||
b_ptr,
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static void
|
||||
RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const AQDataType* aq_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
void* smem_ptr_1,
|
||||
const QuantGroupedGemmKernelArgs& kargs,
|
||||
const typename Base::SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
static_assert(kQuantType == QuantType::BQuantGrouped, "kQuantType must be BQuantGrouped");
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(Base::I2);
|
||||
|
||||
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
tail_num,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -458,6 +458,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
{
|
||||
|
||||
return operator()<TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
@@ -467,5 +468,31 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename BQDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
|
||||
(void)bool_val; // Suppress unused parameter warning
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
return operator()<tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
bq_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, true, tail_number);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -15,6 +15,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
|
||||
Row,
|
||||
@@ -78,6 +79,73 @@ void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_ins
|
||||
PassThrough,
|
||||
AddReluAdd,
|
||||
PassThrough>>>&);
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddReluAdd,
|
||||
PassThrough>>>&);
|
||||
|
||||
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddReluAdd,
|
||||
PassThrough>>>&);
|
||||
|
||||
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Col,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddReluAdd,
|
||||
PassThrough>>>&);
|
||||
|
||||
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Col,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddReluAdd,
|
||||
PassThrough>>>&);
|
||||
#endif
|
||||
|
||||
// GEMM + Add + Relu + Add + Layernorm
|
||||
template <typename ALayout,
|
||||
@@ -136,29 +204,53 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<HLayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<HLayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<HLayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<HLayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
add_instance_library(device_gemm_add_relu_add_layernorm_instance
|
||||
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
|
||||
device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using F16_F16_Tuple = ck::Tuple<F16, F16>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Row_Row_Tuple = ck::Tuple<Row, Row>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
// e = elementwise((a * b), d0, d1)
|
||||
// h = layernorm(e, gamma, beta)
|
||||
// output: h[m, n]
|
||||
// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n]
|
||||
template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
|
||||
using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
|
||||
//##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
|
||||
//##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 2, 2, 16, 16, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 2, 2, 16, 16, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// irregular tile size
|
||||
template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
|
||||
using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_irregular_tile_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
|
||||
//##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
|
||||
//##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 64, 32, 32, 32, 8, 8, 16, 16, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Col,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddReluAdd,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances<
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_irregular_tile_instances<
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,108 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using F16_F16_Tuple = ck::Tuple<F16, F16>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Row_Row_Tuple = ck::Tuple<Row, Row>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
// e = elementwise((a * b), d0, d1)
|
||||
// h = layernorm(e, gamma, beta)
|
||||
// output: h[m, n]
|
||||
// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n]
|
||||
template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
|
||||
using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
|
||||
//##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
|
||||
//##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 4, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 8, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 2, 2, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 4, 2, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 2, 2, 16, 16, 4, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 2, 2, 16, 16, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// irregular tile size
|
||||
template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
|
||||
using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_irregular_tile_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
|
||||
//##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
|
||||
//##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 64, 32, 32, 32, 8, 8, 16, 16, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Col,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddReluAdd,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances<
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_irregular_tile_instances<
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,108 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using F16_F16_Tuple = ck::Tuple<F16, F16>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Row_Row_Tuple = ck::Tuple<Row, Row>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
// e = elementwise((a * b), d0, d1)
|
||||
// h = layernorm(e, gamma, beta)
|
||||
// output: h[m, n]
|
||||
// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n]
|
||||
template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
|
||||
using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
|
||||
//##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
|
||||
//##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 2, 2, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 2, 2, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// irregular tile size
|
||||
template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
|
||||
using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_irregular_tile_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
|
||||
//##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
|
||||
//##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 64, 32, 32, 32, 8, 8, 16, 16, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddReluAdd,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances<
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_irregular_tile_instances<
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,105 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using F16_F16_Tuple = ck::Tuple<F16, F16>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Row_Row_Tuple = ck::Tuple<Row, Row>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
// e = elementwise((a * b), d0, d1)
|
||||
// h = layernorm(e, gamma, beta)
|
||||
// output: h[m, n]
|
||||
// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n]
|
||||
template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
|
||||
using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
|
||||
//##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
|
||||
//##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>,
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
|
||||
using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_irregular_tile_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
|
||||
//##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
|
||||
//##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
|
||||
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 64, 32, 32, 32, 8, 8, 16, 16, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddReluAdd,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances<
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_irregular_tile_instances<
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -167,6 +167,12 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
|
||||
Tensor<HDataType> h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
|
||||
Tensor<HDataType> h_m_n_host(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
|
||||
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
|
||||
std::cout << "h_m_n: " << h_m_n.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
@@ -312,9 +318,8 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
if(time_kernel)
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec
|
||||
<< " GB/s, " << op_name << std::endl;
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
|
||||
<< op_name << std::endl;
|
||||
|
||||
if(ave_time < best_ave_time)
|
||||
{
|
||||
@@ -333,8 +338,7 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
|
||||
}
|
||||
else
|
||||
{
|
||||
if(time_kernel)
|
||||
std::cout << op_name << " does not support this problem" << std::endl;
|
||||
std::cout << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,14 @@ if(CK_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant test_grouped_gemm_quant.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
# Split into three separate test executables for faster parallel compilation
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
|
||||
@@ -22,26 +22,28 @@ using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantTyp
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
|
||||
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
|
||||
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
|
||||
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant>
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_util_quant.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_BQuant = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_BQuant, KernelTypes_BQuant);
|
||||
|
||||
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_BQuant
|
||||
#include "test_grouped_gemm_quant_ut_cases.inc"
|
||||
#undef TEST_CLASS_NAME
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_util_quant.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_RowCol = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_RowCol, KernelTypes_RowCol);
|
||||
|
||||
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_RowCol
|
||||
#include "test_grouped_gemm_quant_ut_cases.inc"
|
||||
#undef TEST_CLASS_NAME
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_util_quant.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_Tensor = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_Tensor, KernelTypes_Tensor);
|
||||
|
||||
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_Tensor
|
||||
#include "test_grouped_gemm_quant_ut_cases.inc"
|
||||
#undef TEST_CLASS_NAME
|
||||
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestCkTileGroupedGemmQuant, Basic)
|
||||
TYPED_TEST(TEST_CLASS_NAME, Basic)
|
||||
{
|
||||
const int group_count = 8;
|
||||
std::vector<int> Ms;
|
||||
@@ -29,7 +29,7 @@ TYPED_TEST(TestCkTileGroupedGemmQuant, Basic)
|
||||
|
||||
// No Hot Loop Test Case, this is to test the correctness of the kernel when there is no hot loop
|
||||
// Using 256x256x128 to match the test kernel's tile size (M_Tile=256, N_Tile=256, K_Tile=128)
|
||||
TYPED_TEST(TestCkTileGroupedGemmQuant, SmallUniform) //
|
||||
TYPED_TEST(TEST_CLASS_NAME, SmallUniform) //
|
||||
{
|
||||
const int group_count = 2;
|
||||
std::vector<int> Ms;
|
||||
@@ -55,3 +55,29 @@ TYPED_TEST(TestCkTileGroupedGemmQuant, SmallUniform) //
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
|
||||
}
|
||||
TYPED_TEST(TEST_CLASS_NAME, OddTail) //
|
||||
{
|
||||
const int group_count = 2;
|
||||
std::vector<int> Ms;
|
||||
std::vector<int> Ns;
|
||||
std::vector<int> Ks;
|
||||
std::vector<int> stride_As;
|
||||
std::vector<int> stride_Bs;
|
||||
std::vector<int> stride_Cs;
|
||||
std::vector<int> stride_AQs;
|
||||
std::vector<int> stride_BQs;
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256);
|
||||
Ns.push_back(256);
|
||||
Ks.push_back(128);
|
||||
|
||||
stride_As.push_back(0);
|
||||
stride_Bs.push_back(0);
|
||||
stride_Cs.push_back(0);
|
||||
stride_AQs.push_back(0);
|
||||
stride_BQs.push_back(0);
|
||||
}
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
|
||||
}
|
||||
|
||||
@@ -17,23 +17,40 @@ template <typename Tuple>
|
||||
class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using AQDataType = std::tuple_element_t<4, Tuple>;
|
||||
using BDataType = std::tuple_element_t<5, Tuple>;
|
||||
using BQDataType = std::tuple_element_t<6, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<7, Tuple>;
|
||||
using CDataType = std::tuple_element_t<8, Tuple>;
|
||||
static constexpr auto QuantType = std::tuple_element_t<9, Tuple>::value;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using AQLayout = Row;
|
||||
using BQLayout = Col;
|
||||
static constexpr bool Persistent = true;
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using AQDataType = std::tuple_element_t<4, Tuple>;
|
||||
using BDataType = std::tuple_element_t<5, Tuple>;
|
||||
using BQDataType = std::tuple_element_t<6, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<7, Tuple>;
|
||||
using CDataType = std::tuple_element_t<8, Tuple>;
|
||||
static constexpr auto QuantType = std::tuple_element_t<9, Tuple>::value;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using AQLayout = Row;
|
||||
using BQLayout = Col;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool PreshuffleB = std::tuple_element_t<10, Tuple>::value;
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
static constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
struct GroupedGemKernelParam_Mfma
|
||||
{
|
||||
@@ -52,7 +69,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
|
||||
static const ck_tile::index_t M_Warp_Tile = 32;
|
||||
static const ck_tile::index_t N_Warp_Tile = 32;
|
||||
static const ck_tile::index_t K_Warp_Tile = 16;
|
||||
static const ck_tile::index_t K_Warp_Tile =
|
||||
TestCkTileGroupedGemmQuant::template get_k_from_preshuffled_warp_tile<BDataType,
|
||||
M_Warp_Tile>();
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
@@ -66,8 +85,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr bool TransposeC = false;
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
constexpr bool TransposeC = false;
|
||||
constexpr bool DoubleSmemBuffer =
|
||||
PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
@@ -90,7 +110,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
false,
|
||||
false,
|
||||
PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
@@ -126,11 +146,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
BDataType,
|
||||
scheduler>>::type;
|
||||
|
||||
using GemmPipeline = typename std::conditional<
|
||||
QuantType == ck_tile::QuantType::BQuantGrouped,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>::type;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantType == ck_tile::QuantType::RowColQuant ||
|
||||
QuantType == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
std::conditional_t<PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -344,7 +366,18 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
bq_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
|
||||
if constexpr(PreshuffleB && QuantType == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
auto b_shuffle_host =
|
||||
ck_tile::shuffle_b<GroupedGemKernelParam_Mfma>(b_k_n_tensors[i]);
|
||||
b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
}
|
||||
|
||||
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
|
||||
bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
|
||||
c_m_n_dev_buf[i]->SetZero();
|
||||
@@ -485,3 +518,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
// Aliases for split test files
|
||||
template <typename Tuple>
|
||||
using TestCkTileGroupedGemmQuant_RowCol = TestCkTileGroupedGemmQuant<Tuple>;
|
||||
|
||||
template <typename Tuple>
|
||||
using TestCkTileGroupedGemmQuant_Tensor = TestCkTileGroupedGemmQuant<Tuple>;
|
||||
|
||||
template <typename Tuple>
|
||||
using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant<Tuple>;
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
add_custom_target(test_gemm_layernorm)
|
||||
target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance)
|
||||
add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16)
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp)
|
||||
if(result EQUAL 0)
|
||||
add_custom_target(test_gemm_layernorm)
|
||||
target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance)
|
||||
add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -79,11 +79,6 @@ TYPED_TEST_SUITE(TestGemmAddReluAddLayernorm, KernelTypes);
|
||||
TYPED_TEST(TestGemmAddReluAddLayernorm, Test_FP16) { this->Run(); }
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
std::cout << "No available instance for gfx11 & gfx12." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@@ -125,38 +125,13 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[32, 32, 64],
|
||||
],
|
||||
},
|
||||
"gfx1201": {
|
||||
"gfx1201": { # Check how to handle for GEMM and Multi D
|
||||
"fp16_fp16_fp16": [
|
||||
[16, 16, 16],
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# Supported warp tile combinations for different GPU architectures and data types
|
||||
WARP_SUPPORTED_COMBINATIONS = {
|
||||
"gfx90a": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx942": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx950": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx1201": [
|
||||
[2, 4, 1],
|
||||
[1, 8, 1],
|
||||
[8, 1, 1],
|
||||
[4, 2, 1],
|
||||
],
|
||||
}
|
||||
|
||||
# Unsupported trait combinations
|
||||
TRAIT_UNSUPPORTED_COMBINATIONS = {
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
@@ -441,6 +416,20 @@ def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]:
|
||||
return a_layout, b_layout, c_layout
|
||||
|
||||
|
||||
def get_abcd_layouts(layout_code: str) -> Tuple[str, str, str, List[str]]:
|
||||
"""
|
||||
Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcrr', 'ccrr', 'crrr', 'rrrr'.
|
||||
"""
|
||||
code = str(layout_code).strip().lower()
|
||||
|
||||
a_layout = LAYOUT_MAP[code[0]]
|
||||
b_layout = LAYOUT_MAP[code[1]]
|
||||
c_layout = LAYOUT_MAP[code[2]]
|
||||
d0_layout = LAYOUT_MAP[code[3]]
|
||||
d1_layout = LAYOUT_MAP[code[3]]
|
||||
return a_layout, b_layout, c_layout, [d0_layout, d1_layout]
|
||||
|
||||
|
||||
def validate_whole_wg_cover_configuration(
|
||||
tile_m,
|
||||
tile_n,
|
||||
@@ -464,13 +453,13 @@ def validate_whole_wg_cover_configuration(
|
||||
|
||||
# A matrix validation
|
||||
if layout[0] == "r":
|
||||
XPerTile = tile_k
|
||||
YPerTile = tile_m
|
||||
|
||||
vector_load_size = get_global_vector_load_size(
|
||||
BlockSize, tile_k, a_datatype, tile_m, tile_k
|
||||
)
|
||||
|
||||
XPerTile = tile_k
|
||||
YPerTile = tile_m
|
||||
|
||||
elif layout[0] == "c":
|
||||
vector_load_size = get_global_vector_load_size(
|
||||
BlockSize, tile_k, a_datatype, tile_m, tile_m
|
||||
@@ -485,7 +474,6 @@ def validate_whole_wg_cover_configuration(
|
||||
)
|
||||
|
||||
if not wg_cover_core_valid:
|
||||
print("I am here 1")
|
||||
logging.debug(
|
||||
f"whole workgroup cover failed for Matrix A distribution: {wg_cover_core_error}"
|
||||
)
|
||||
@@ -521,7 +509,7 @@ def validate_whole_wg_cover_configuration(
|
||||
if not wg_cover_core_valid:
|
||||
print("I am here 3")
|
||||
logging.debug(
|
||||
f"whole workgroup cover failed for Matrix A distribution: {wg_cover_core_error}"
|
||||
f"whole workgroup cover failed for Matrix B distribution: {wg_cover_core_error}"
|
||||
)
|
||||
return False, wg_cover_core_error
|
||||
|
||||
@@ -540,7 +528,6 @@ def validate_whole_wg_cover_configuration(
|
||||
XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
|
||||
)
|
||||
if not wg_cover_core_valid:
|
||||
print("I am here 4")
|
||||
logging.debug(
|
||||
f"whole workgroup cover failed for Matrix B: {wg_cover_core_error}"
|
||||
)
|
||||
@@ -557,7 +544,7 @@ def wg_cover_core_validation(
|
||||
warp_size: int,
|
||||
) -> Tuple[bool, str]:
|
||||
if XPerTile % vector_load_size != 0:
|
||||
return False
|
||||
return False, "XPerTile is not divisible by vector_load_size"
|
||||
|
||||
num_warps = BlockSize / warp_size
|
||||
LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size)
|
||||
@@ -567,7 +554,7 @@ def wg_cover_core_validation(
|
||||
Y1 = warp_size // X0
|
||||
|
||||
if X0 * Y1 != warp_size:
|
||||
return False, ""
|
||||
return False, "X0 * Y1 != warp_size"
|
||||
|
||||
return True, ""
|
||||
|
||||
@@ -583,9 +570,9 @@ def get_global_vector_load_size(
|
||||
PackedSize = 1
|
||||
|
||||
if (
|
||||
XPerTile % (PackedSize * 32 / element_size(DataType)) == 0
|
||||
PackedSize == 2
|
||||
and XPerTile % (PackedSize * 32 / element_size(DataType)) == 0
|
||||
and elements_per_thread % (PackedSize * 32 / element_size(DataType)) == 0
|
||||
and PackedSize == 2
|
||||
):
|
||||
return PackedSize * 32 / element_size(DataType)
|
||||
elif (
|
||||
@@ -122,15 +122,15 @@ function(build_individual_gemm_targets datatype layout)
|
||||
if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "")
|
||||
set(config_filename "$ENV{GEMM_CONFIG_FILE}")
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
|
||||
message(STATUS " Using config from environment variable: ${config_filename}")
|
||||
message(VERBOSE " Using config from environment variable: ${config_filename}")
|
||||
elseif(NOT "${GEMM_CONFIG_FILE}" STREQUAL "")
|
||||
# Use CMake variable if set
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_CONFIG_FILE}")
|
||||
message(STATUS " Using custom config: ${GEMM_CONFIG_FILE}")
|
||||
message(VERBOSE " Using custom config: ${GEMM_CONFIG_FILE}")
|
||||
else()
|
||||
# Use default config for all layouts
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
|
||||
message(STATUS " Using default config for layout ${layout}")
|
||||
message(VERBOSE " Using default config for layout ${layout}")
|
||||
endif()
|
||||
|
||||
# Check if config file exists
|
||||
@@ -151,16 +151,16 @@ function(build_individual_gemm_targets datatype layout)
|
||||
endif()
|
||||
|
||||
# Generate individual kernel files using parallel version
|
||||
message(STATUS "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
|
||||
message(STATUS " Working path: ${working_path}")
|
||||
message(STATUS " Config file: ${json_blob}")
|
||||
message(STATUS " Python executable: ${Python3_EXECUTABLE}")
|
||||
message(STATUS " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py")
|
||||
message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
|
||||
message(VERBOSE " Working path: ${working_path}")
|
||||
message(VERBOSE " Config file: ${json_blob}")
|
||||
message(VERBOSE " Python executable: ${Python3_EXECUTABLE}")
|
||||
message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py")
|
||||
|
||||
# Create working directory first
|
||||
file(MAKE_DIRECTORY ${working_path})
|
||||
|
||||
message(STATUS "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
@@ -169,7 +169,7 @@ function(build_individual_gemm_targets datatype layout)
|
||||
--list_kernels ")
|
||||
|
||||
# First, just list the kernels (fast operation)
|
||||
message(STATUS " Listing kernel configurations...")
|
||||
message(VERBOSE " Listing kernel configurations...")
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
@@ -192,7 +192,7 @@ function(build_individual_gemm_targets datatype layout)
|
||||
if(EXISTS ${working_path}/gemm_kernel_count.txt)
|
||||
file(READ ${working_path}/gemm_kernel_count.txt kernel_count)
|
||||
string(STRIP "${kernel_count}" kernel_count)
|
||||
message(STATUS " Found ${kernel_count} kernel configurations")
|
||||
message(VERBOSE " Found ${kernel_count} kernel configurations")
|
||||
else()
|
||||
message(FATAL_ERROR "Kernel count file not found")
|
||||
endif()
|
||||
@@ -216,10 +216,10 @@ function(build_individual_gemm_targets datatype layout)
|
||||
endfunction()
|
||||
|
||||
# Main build logic - Only individual builds supported
|
||||
message(STATUS "=== Starting Tile Engine GEMM Configuration ===")
|
||||
message(STATUS "GEMM_DATATYPE: ${GEMM_DATATYPE}")
|
||||
message(STATUS "GEMM_LAYOUT: ${GEMM_LAYOUT}")
|
||||
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
message(VERBOSE "=== Starting Tile Engine GEMM Configuration ===")
|
||||
message(VERBOSE "GEMM_DATATYPE: ${GEMM_DATATYPE}")
|
||||
message(VERBOSE "GEMM_LAYOUT: ${GEMM_LAYOUT}")
|
||||
message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201
|
||||
set(GEMM_GPU_TARGETS_INDIVIDUAL "")
|
||||
@@ -228,7 +228,7 @@ set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
list(APPEND GEMM_GPU_TARGETS_INDIVIDUAL ${target})
|
||||
message(STATUS " Adding GPU target: ${target}")
|
||||
message(VERBOSE " Adding GPU target: ${target}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -236,7 +236,7 @@ endforeach()
|
||||
if(NOT GEMM_GPU_TARGETS_INDIVIDUAL)
|
||||
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
else()
|
||||
message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}")
|
||||
message(VERBOSE "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}")
|
||||
|
||||
# Enable parallel compilation optimizations
|
||||
# Set up job pools for better parallel compilation control
|
||||
@@ -251,12 +251,12 @@ else()
|
||||
find_program(CCACHE_PROGRAM ccache)
|
||||
if(CCACHE_PROGRAM)
|
||||
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
|
||||
message(STATUS "Using ccache for faster compilation")
|
||||
message(VERBOSE "Using ccache for faster compilation")
|
||||
else()
|
||||
message(WARNING "ccache requested but not found")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "ccache disabled for GEMM ops (use -DENABLE_CCACHE_GEMM=ON to enable)")
|
||||
message(VERBOSE "ccache disabled for GEMM ops (use -DENABLE_CCACHE_GEMM=ON to enable)")
|
||||
endif()
|
||||
|
||||
# Create master collection targets
|
||||
|
||||
@@ -8,12 +8,30 @@ import multiprocessing
|
||||
import concurrent.futures
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from commons.validation_utils import (
|
||||
is_tile_config_valid,
|
||||
is_trait_combination_valid,
|
||||
get_dtype_string,
|
||||
get_abc_layouts,
|
||||
)
|
||||
import importlib.util
|
||||
|
||||
|
||||
def _import_validation_utils():
|
||||
"""Import validation utilities from commons directory."""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(current_dir)
|
||||
|
||||
# Load the module dynamically
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"validation_utils", os.path.join(parent_dir, "commons", "validation_utils.py")
|
||||
)
|
||||
validation_utils = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(validation_utils)
|
||||
|
||||
return validation_utils
|
||||
|
||||
|
||||
# Import validation functions
|
||||
_validation_utils = _import_validation_utils()
|
||||
is_tile_config_valid = _validation_utils.is_tile_config_valid
|
||||
is_trait_combination_valid = _validation_utils.is_trait_combination_valid
|
||||
get_dtype_string = _validation_utils.get_dtype_string
|
||||
get_abc_layouts = _validation_utils.get_abc_layouts
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
@@ -563,6 +581,8 @@ struct SelectedKernel {{
|
||||
tile_configs = self._get_tile_configs()
|
||||
trait_combos = self._generate_trait_combinations()
|
||||
k_block_per_cu = self.config.get("k_block_per_cu")
|
||||
if k_block_per_cu is None:
|
||||
k_block_per_cu = 1
|
||||
|
||||
# Prepare work items for parallel processing
|
||||
work_items = []
|
||||
@@ -574,11 +594,12 @@ struct SelectedKernel {{
|
||||
trait_combo,
|
||||
k_block_per_cu,
|
||||
self.working_path,
|
||||
self.gpu_target,
|
||||
self.datatype,
|
||||
self.layout,
|
||||
self.config_json,
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
|
||||
)
|
||||
@@ -615,7 +636,6 @@ struct SelectedKernel {{
|
||||
print(
|
||||
f" Progress: {completed}/{len(work_items)} kernels generated"
|
||||
)
|
||||
|
||||
try:
|
||||
result = future.result()
|
||||
if result:
|
||||
@@ -662,10 +682,19 @@ struct SelectedKernel {{
|
||||
|
||||
def _generate_single_kernel_individual(work_item):
|
||||
"""Worker function to generate a single individual kernel file"""
|
||||
tile_config, trait_combo, k_block_per_cu, working_path, datatype, layout = work_item
|
||||
(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
k_block_per_cu,
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
config_json,
|
||||
) = work_item
|
||||
|
||||
# Create a temporary builder instance for this worker
|
||||
builder = GemmKernelBuilder(working_path, datatype, layout)
|
||||
builder = GemmKernelBuilder(working_path, gpu_target, datatype, layout, config_json)
|
||||
|
||||
try:
|
||||
kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
@@ -798,6 +827,8 @@ def main():
|
||||
)
|
||||
|
||||
k_block_per_cu = builder.config.get("k_block_per_cu")
|
||||
if k_block_per_cu is None:
|
||||
k_block_per_cu = 1
|
||||
|
||||
# Generate the kernel
|
||||
kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
|
||||
@@ -1,175 +1,311 @@
|
||||
|
||||
set(GEMM_MULTI_D_DATATYPE "fp16" CACHE STRING "List of datatypes for GEMM Multi D (semicolon-separated)")
|
||||
set(GEMM_MULTI_D_LAYOUT "rcrr" CACHE STRING "List of layout for GEMM Multi D(semicolon-separated)")
|
||||
set(GEMM_MULTI_D_LAYOUT "rcrr;rrrr;crrr;ccrr" CACHE STRING "List of layout for GEMM Multi D (semicolon-separated)")
|
||||
set(GEMM_MULTI_D_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
|
||||
set(GEMM_MULTI_D_ELEMENTWISE_FUNCTION "mul" CACHE STRING "Elementwise function")
|
||||
|
||||
function(build_gemm_multi_d_for_datatype_layout datatype layout)
|
||||
# Filter GPU targets to only gfx90a, gfx942, and gfx950
|
||||
set(GEMM_GPU_TARGETS "")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950")
|
||||
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
list(APPEND GEMM_GPU_TARGETS ${target})
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# Skip compilation if no matching targets found
|
||||
if(NOT GEMM_GPU_TARGETS)
|
||||
message(WARNING "Skipping Tile Engine GEMM Multi D compilation: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
option(ENABLE_CCACHE_GEMM_MULTI_D "Enable ccache for GEMM Multi D ops compilation" OFF)
|
||||
|
||||
# Store the directory path for use in functions
|
||||
set(GEMM_MULTI_D_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
# Function to create individual GEMM Multi D targets
|
||||
function(create_individual_gemm_multi_d_target datatype layout trait tile_config config_json)
|
||||
# Use the parent scope GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL variable
|
||||
if(NOT GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL)
|
||||
message(WARNING "Skipping individual GEMM Multi D target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets")
|
||||
return()
|
||||
endif()
|
||||
|
||||
message(STATUS "Building GEMM Multi D for GPU targets: ${GEMM_GPU_TARGETS}")
|
||||
|
||||
|
||||
# Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k
|
||||
# First split by underscore to get three groups
|
||||
string(REPLACE "_" ";" config_groups ${tile_config})
|
||||
list(GET config_groups 0 tile_dims) # e.g., 256x256x32
|
||||
list(GET config_groups 1 warp_dims) # e.g., 4x1x1
|
||||
list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16
|
||||
|
||||
# Parse tile dimensions
|
||||
string(REPLACE "x" ";" tile_parts ${tile_dims})
|
||||
list(GET tile_parts 0 tile_m)
|
||||
list(GET tile_parts 1 tile_n)
|
||||
list(GET tile_parts 2 tile_k)
|
||||
|
||||
# Parse warp dimensions
|
||||
string(REPLACE "x" ";" warp_parts ${warp_dims})
|
||||
list(GET warp_parts 0 warp_m)
|
||||
list(GET warp_parts 1 warp_n)
|
||||
list(GET warp_parts 2 warp_k)
|
||||
|
||||
# Parse warp tile dimensions
|
||||
string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims})
|
||||
list(GET warp_tile_parts 0 warp_tile_m)
|
||||
list(GET warp_tile_parts 1 warp_tile_n)
|
||||
list(GET warp_tile_parts 2 warp_tile_k)
|
||||
|
||||
set(target_name "benchmark_gemm_multi_d_${datatype}_${layout}_${trait}_${tile_config}")
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
|
||||
|
||||
# Comment this if-else block when using user_provided_config
|
||||
if(layout STREQUAL "rcrr")
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
|
||||
# Generate the single instance header for this kernel
|
||||
set(instance_header "${working_path}/gemm_multi_d_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
|
||||
|
||||
# Add custom command to generate the header file at build time
|
||||
add_custom_command(
|
||||
OUTPUT ${instance_header}
|
||||
COMMAND ${Python3_EXECUTABLE} ${GEMM_MULTI_D_SOURCE_DIR}/gemm_multi_d_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
|
||||
--config_json ${config_json}
|
||||
--gen_single
|
||||
--kernel_name "gemm_multi_d_${datatype}_${layout}_${trait}_${tile_config}"
|
||||
--tile_config "${tile_config}"
|
||||
--trait_combo "${trait}"
|
||||
--gpu_target "${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}"
|
||||
DEPENDS ${GEMM_MULTI_D_SOURCE_DIR}/gemm_multi_d_instance_builder.py ${config_json}
|
||||
COMMENT "Generating ${instance_header}"
|
||||
)
|
||||
|
||||
# Create the executable
|
||||
add_executable(${target_name}
|
||||
EXCLUDE_FROM_ALL
|
||||
${GEMM_MULTI_D_SOURCE_DIR}/gemm_multi_d_benchmark_single.cpp
|
||||
${instance_header}
|
||||
)
|
||||
|
||||
# Set GPU architectures
|
||||
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL})
|
||||
|
||||
# Set compile definitions
|
||||
target_compile_definitions(${target_name} PRIVATE
|
||||
GEMM_MULTI_D_SINGLE_INSTANCE_HPP="${instance_header}"
|
||||
)
|
||||
|
||||
# Include directories
|
||||
target_include_directories(${target_name} PRIVATE
|
||||
${GEMM_MULTI_D_SOURCE_DIR}
|
||||
${working_path}
|
||||
)
|
||||
|
||||
# Compile options
|
||||
target_compile_options(${target_name} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
--offload-compress
|
||||
-include ${instance_header}
|
||||
)
|
||||
|
||||
# Add to collection targets
|
||||
add_dependencies(benchmark_gemm_multi_d_all ${target_name})
|
||||
add_dependencies(benchmark_gemm_multi_d_${datatype} ${target_name})
|
||||
add_dependencies(benchmark_gemm_multi_d_${layout} ${target_name})
|
||||
add_dependencies(benchmark_gemm_multi_d_${datatype}_${layout} ${target_name})
|
||||
|
||||
# Add to trait-specific targets
|
||||
string(REPLACE "_" ";" trait_parts ${trait})
|
||||
list(GET trait_parts 0 pipeline)
|
||||
list(GET trait_parts 1 epilogue)
|
||||
list(GET trait_parts 2 scheduler)
|
||||
|
||||
add_dependencies(benchmark_gemm_multi_d_${pipeline}_pipeline ${target_name})
|
||||
add_dependencies(benchmark_gemm_multi_d_${epilogue}_epilogue ${target_name})
|
||||
add_dependencies(benchmark_gemm_multi_d_${scheduler}_scheduler ${target_name})
|
||||
endfunction()
|
||||
|
||||
# Function to build individual GEMM Multi D targets
|
||||
function(build_individual_gemm_multi_d_targets datatype layout)
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
|
||||
|
||||
# Choose config file
|
||||
# Priority order:
|
||||
# 1. Environment variable GEMM_MULTI_D_CONFIG_FILE
|
||||
# 2. CMake variable GEMM_MULTI_D_CONFIG_FILE
|
||||
# 3. Default based on layout
|
||||
|
||||
# Check environment variable first
|
||||
if(DEFINED ENV{GEMM_MULTI_D_CONFIG_FILE} AND NOT "$ENV{GEMM_MULTI_D_CONFIG_FILE}" STREQUAL "")
|
||||
set(config_filename "$ENV{GEMM_MULTI_D_CONFIG_FILE}")
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
|
||||
message(VERBOSE " Using config from environment variable: ${config_filename}")
|
||||
elseif(NOT "${GEMM_MULTI_D_CONFIG_FILE}" STREQUAL "")
|
||||
# Use CMake variable if set
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_MULTI_D_CONFIG_FILE}")
|
||||
message(VERBOSE " Using custom config: ${GEMM_MULTI_D_CONFIG_FILE}")
|
||||
else()
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/custom_ci_config.json")
|
||||
# Use default config for all layouts
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
|
||||
message(VERBOSE " Using default config for layout ${layout}")
|
||||
endif()
|
||||
|
||||
# uncomment this if you want to use user_provided_config.json
|
||||
# set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json")
|
||||
|
||||
# Generate kernel list
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py
|
||||
# Check if config file exists
|
||||
if(NOT EXISTS ${json_blob})
|
||||
message(FATAL_ERROR "Config file not found: ${json_blob}")
|
||||
endif()
|
||||
|
||||
# Determine number of workers for parallel generation
|
||||
if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL})
|
||||
set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL})
|
||||
else()
|
||||
# Use processor count but limit to avoid memory issues
|
||||
cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES)
|
||||
math(EXPR num_workers "${num_cores}")
|
||||
if(num_workers GREATER 8)
|
||||
set(num_workers 8)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Generate individual kernel files using parallel version
|
||||
message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
|
||||
message(VERBOSE " Working path: ${working_path}")
|
||||
message(VERBOSE " Config file: ${json_blob}")
|
||||
message(VERBOSE " Python executable: ${Python3_EXECUTABLE}")
|
||||
message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py")
|
||||
|
||||
# Create working directory first
|
||||
file(MAKE_DIRECTORY ${working_path})
|
||||
|
||||
message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
|
||||
--config_json ${json_blob}
|
||||
--list_blobs
|
||||
--gpu_target ${GEMM_GPU_TARGETS}
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${ret}")
|
||||
endif()
|
||||
--gpu_target ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}
|
||||
--list_kernels ")
|
||||
|
||||
file(STRINGS "${working_path}/gemm_multi_d_instance_blobs.txt" codegen_blobs)
|
||||
file(STRINGS "${working_path}/gemm_multi_d_instance_blobs_range.txt" codegen_blobs_range)
|
||||
|
||||
# Generate the blobs
|
||||
add_custom_command(
|
||||
OUTPUT ${codegen_blobs}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py
|
||||
--working_path "${working_path}"
|
||||
# First, just list the kernels (fast operation)
|
||||
message(VERBOSE " Listing kernel configurations...")
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
|
||||
--config_json "${json_blob}"
|
||||
--gen_blobs
|
||||
--gpu_target ${GEMM_GPU_TARGETS}
|
||||
COMMENT "Generating GEMM Multi D instance sources for ${datatype} ${layout}"
|
||||
--config_json ${json_blob}
|
||||
--gpu_target ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}
|
||||
--list_kernels
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
OUTPUT_VARIABLE list_output
|
||||
ERROR_VARIABLE list_error
|
||||
)
|
||||
add_custom_target(gemm_multi_d_gen_${datatype}_${layout} DEPENDS ${codegen_blobs})
|
||||
|
||||
set(intermediate_libs)
|
||||
list(LENGTH codegen_blobs codegen_blobs_len)
|
||||
if(NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
|
||||
endif()
|
||||
|
||||
foreach(blob IN LISTS codegen_blobs_range)
|
||||
string(STRIP "${blob}" stripped_blob)
|
||||
separate_arguments(spilit_blob UNIX_COMMAND "${stripped_blob}")
|
||||
# Each line is: <trait_name> <first_index_inclusive> <last_index_exclusive>
|
||||
list(GET spilit_blob 0 name)
|
||||
list(GET spilit_blob 1 first)
|
||||
list(GET spilit_blob 2 last)
|
||||
math(EXPR total_files "${last} - ${first}")
|
||||
if(total_files EQUAL 0)
|
||||
continue() # nothing for this trait
|
||||
endif()
|
||||
# Read kernel count
|
||||
if(EXISTS ${working_path}/gemm_multi_d_kernel_count.txt)
|
||||
file(READ ${working_path}/gemm_multi_d_kernel_count.txt kernel_count)
|
||||
string(STRIP "${kernel_count}" kernel_count)
|
||||
message(VERBOSE " Found ${kernel_count} kernel configurations")
|
||||
else()
|
||||
message(FATAL_ERROR "Kernel count file not found")
|
||||
endif()
|
||||
|
||||
# Object libraries (chunked) per trait
|
||||
set(sub_intermediate_libs)
|
||||
set(chunk_size 3)
|
||||
math(EXPR num_chunks "( ${total_files} + ${chunk_size} - 1 ) / ${chunk_size}")
|
||||
math(EXPR num_chunks_minus_1 "${num_chunks} - 1")
|
||||
|
||||
foreach(i RANGE 0 ${num_chunks_minus_1})
|
||||
math(EXPR start "${first} + ${i} * ${chunk_size} ")
|
||||
math(EXPR end "${start} + ${chunk_size} - 1")
|
||||
|
||||
set(chunk_files)
|
||||
foreach(j RANGE ${start} ${end})
|
||||
if(j LESS ${last} AND j LESS ${codegen_blobs_len})
|
||||
list(GET codegen_blobs ${j} f)
|
||||
list(APPEND chunk_files "${f}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
#list(LENGTH chunk_files chunk_files_len)
|
||||
#if(chunk_files_len AND chunk_files_len GREATER 1)
|
||||
if(chunk_files)
|
||||
set(sub_intermediate_lib_name "gemm_multi_d_objlib_${name}_${i}_${datatype}_${layout}")
|
||||
add_library(${sub_intermediate_lib_name} OBJECT ${chunk_files})
|
||||
set_property(TARGET ${sub_intermediate_lib_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS})
|
||||
list(APPEND sub_intermediate_libs ${sub_intermediate_lib_name})
|
||||
endif()
|
||||
# Read kernel list and create targets
|
||||
if(EXISTS ${working_path}/gemm_multi_d_kernel_list.txt)
|
||||
file(STRINGS ${working_path}/gemm_multi_d_kernel_list.txt kernel_lines)
|
||||
foreach(line IN LISTS kernel_lines)
|
||||
# Parse line: kernel_name|tile_config|trait_combo
|
||||
string(REPLACE "|" ";" parts "${line}")
|
||||
list(GET parts 0 kernel_name)
|
||||
list(GET parts 1 tile_config)
|
||||
list(GET parts 2 trait_combo)
|
||||
|
||||
# Create individual target
|
||||
create_individual_gemm_multi_d_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}")
|
||||
endforeach()
|
||||
|
||||
# ------------------ Bundle the object libs into one static lib ---------
|
||||
#list(LENGTH sub_intermediate_libs sub_intermediate_libs_len)
|
||||
#if(sub_intermediate_libs AND sub_intermediate_libs_len GREATER 1)
|
||||
if(sub_intermediate_libs)
|
||||
set(intermediate_lib_name "gemm_multi_d_staticlib_${name}_${datatype}_${layout}")
|
||||
# Collect the $<TARGET_OBJECTS:...> expressions
|
||||
|
||||
set(obj_exprs)
|
||||
foreach(objlib IN LISTS sub_intermediate_libs)
|
||||
list(APPEND obj_exprs $<TARGET_OBJECTS:${objlib}>)
|
||||
endforeach()
|
||||
|
||||
add_library(${intermediate_lib_name} STATIC ${obj_exprs})
|
||||
add_dependencies(${intermediate_lib_name} gemm_multi_d_gen_${datatype}_${layout})
|
||||
set_property(TARGET ${intermediate_lib_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS})
|
||||
#foreach(objlib IN LISTS sub_intermediate_libs)
|
||||
# target_sources(${intermediate_lib_name} PRIVATE $<TARGET_OBJECTS:${objlib}>)
|
||||
#endforeach()
|
||||
list(APPEND intermediate_libs ${intermediate_lib_name})
|
||||
endif()
|
||||
|
||||
endforeach()
|
||||
|
||||
# Interface library for instances
|
||||
add_library(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE)
|
||||
add_dependencies(gemm_multi_d_template_instances_${datatype}_${layout} gemm_multi_d_gen_${datatype}_${layout})
|
||||
target_link_libraries(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE ${intermediate_libs})
|
||||
target_include_directories(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
"${working_path}"
|
||||
)
|
||||
set_target_properties(gemm_multi_d_template_instances_${datatype}_${layout} PROPERTIES LINKER_LANGUAGE CXX)
|
||||
|
||||
# Host API interface library
|
||||
add_library(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE)
|
||||
target_link_libraries(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE gemm_multi_d_template_instances_${datatype}_${layout})
|
||||
target_include_directories(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
"${working_path}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
# Executable per datatype
|
||||
set(exec_name "benchmark_gemm_multi_d_${datatype}_${layout}")
|
||||
add_executable(${exec_name} benchmark_gemm_multi_d.cpp)
|
||||
set_property(TARGET ${exec_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS})
|
||||
target_link_libraries(${exec_name} PRIVATE gemm_multi_d_host_api_${datatype}_${layout})
|
||||
target_compile_options(${exec_name} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
--offload-compress
|
||||
)
|
||||
else()
|
||||
message(FATAL_ERROR "Kernel list file not found")
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# Process each datatype in isolation
|
||||
foreach(dt IN LISTS GEMM_MULTI_D_DATATYPE)
|
||||
foreach(l IN LISTS GEMM_MULTI_D_LAYOUT)
|
||||
build_gemm_multi_d_for_datatype_layout(${dt} ${l})
|
||||
endforeach()
|
||||
# Main build logic - Only individual builds supported
|
||||
message(VERBOSE "=== Starting Tile Engine GEMM Multi D Configuration ===")
|
||||
message(VERBOSE "GEMM_MULTI_D_DATATYPE: ${GEMM_MULTI_D_DATATYPE}")
|
||||
message(VERBOSE "GEMM_MULTI_D_LAYOUT: ${GEMM_MULTI_D_LAYOUT}")
|
||||
message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
# Filter GPU targets to only gfx90a, gfx942, gfx950
|
||||
set(GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL "")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950")
|
||||
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
list(APPEND GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL ${target})
|
||||
message(VERBOSE " Adding GPU target: ${target}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# Skip build if no matching targets found
|
||||
if(NOT GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL)
|
||||
message(WARNING "Skipping Tile Engine GEMM Multi D build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
else()
|
||||
message(VERBOSE "Building individual GEMM Multi D targets for GPU targets: ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}")
|
||||
|
||||
# Enable parallel compilation optimizations
|
||||
# Set up job pools for better parallel compilation control
|
||||
set_property(GLOBAL PROPERTY JOB_POOLS
|
||||
compile_heavy=4 # Limit heavy compilations to prevent OOM
|
||||
compile_normal=16 # Allow more parallel normal compilations
|
||||
)
|
||||
|
||||
# Enable compiler cache if available and explicitly requested
|
||||
# Disabled by default due to permission issues in CI environments
|
||||
if(ENABLE_CCACHE_GEMM_MULTI_D)
|
||||
find_program(CCACHE_PROGRAM ccache)
|
||||
if(CCACHE_PROGRAM)
|
||||
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
|
||||
message(VERBOSE "Using ccache for faster compilation")
|
||||
else()
|
||||
message(WARNING "ccache requested but not found")
|
||||
endif()
|
||||
else()
|
||||
message(VERBOSE "ccache disabled for GEMM Multi D ops (use -DENABLE_CCACHE_GEMM_MULTI_D=ON to enable)")
|
||||
endif()
|
||||
|
||||
# Create master collection targets
|
||||
add_custom_target(benchmark_gemm_multi_d_all)
|
||||
|
||||
# Create datatype collection targets
|
||||
foreach(dt IN LISTS GEMM_MULTI_D_DATATYPE)
|
||||
add_custom_target(benchmark_gemm_multi_d_${dt})
|
||||
endforeach()
|
||||
|
||||
# Create layout collection targets
|
||||
foreach(l IN LISTS GEMM_MULTI_D_LAYOUT)
|
||||
add_custom_target(benchmark_gemm_multi_d_${l})
|
||||
endforeach()
|
||||
|
||||
# Create combined collection targets
|
||||
foreach(dt IN LISTS GEMM_MULTI_D_DATATYPE)
|
||||
foreach(l IN LISTS GEMM_MULTI_D_LAYOUT)
|
||||
add_custom_target(benchmark_gemm_multi_d_${dt}_${l})
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
# Create trait-based collection targets
|
||||
# These are common trait components used across all GEMM Multi D kernels
|
||||
set(GEMM_MULTI_D_PIPELINES "mem;compv3;compv4")
|
||||
set(GEMM_MULTI_D_EPILOGUES "default;cshuffle")
|
||||
set(GEMM_MULTI_D_SCHEDULERS "intrawave;interwave")
|
||||
|
||||
foreach(pipeline IN LISTS GEMM_MULTI_D_PIPELINES)
|
||||
add_custom_target(benchmark_gemm_multi_d_${pipeline}_pipeline)
|
||||
endforeach()
|
||||
|
||||
foreach(epilogue IN LISTS GEMM_MULTI_D_EPILOGUES)
|
||||
add_custom_target(benchmark_gemm_multi_d_${epilogue}_epilogue)
|
||||
endforeach()
|
||||
|
||||
foreach(scheduler IN LISTS GEMM_MULTI_D_SCHEDULERS)
|
||||
add_custom_target(benchmark_gemm_multi_d_${scheduler}_scheduler)
|
||||
endforeach()
|
||||
|
||||
# Build individual targets for each datatype/layout combination
|
||||
foreach(dt IN LISTS GEMM_MULTI_D_DATATYPE)
|
||||
foreach(l IN LISTS GEMM_MULTI_D_LAYOUT)
|
||||
build_individual_gemm_multi_d_targets(${dt} ${l})
|
||||
endforeach()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
|
||||
CK Tile Engine for GEMM Multi D is used to generate and run GEMM kernels with different combinations of BlockTile sizes, WarpTile sizes, WarpTile mapping for all valid pipelines, schedulers and epilogues while able to give custom datatype and Layout selections
|
||||
|
||||
# Kernel Configurations
|
||||
|
||||
# User Specific
|
||||
Users can specify custom kernel configurations such as tile size, warp size, padding, pipeline, scheduler, and epilogue in the config file. This allows building only for selected configurations, significantly reducing build time.
|
||||
For reference please see `./configs/user_provided_config.json`.
|
||||
|
||||
# Default
|
||||
The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark. For reference please see in `./configs/default_config.json`
|
||||
|
||||
If user does not provide kernel configuration, the tile engine uses default kernel configuration to generate kernel instances and benchmark.
|
||||
|
||||
## Build Instructions
|
||||
``` bash
|
||||
# in the root of composable kernel create build directory
|
||||
mkdir build && cd build
|
||||
# build composable kernel
|
||||
# replace [Arch] with the appropriate architecture or leave blank and
|
||||
# replace [Datatype] in comma separated datatypes string (possible datatypes are [fp16])
|
||||
# replace [Layout1;Layout2;...] in comma separated datatypes string (possible layouts are [rcr, rrr, crr, ccr])
|
||||
# replace "mul" with either of mul,add,passthrough for Elementwise function as Multiply, Add or Passthrough respectively. If this is not specified it is considered as mul by default.
|
||||
../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_MULTI_D_DATATYPE="[Datatype]" -DGEMM_MULTI_D_LAYOUT="[Layout1;Layout2]" -DGEMM_MULTI_D_ELEMENTWISE_FUNCTION="mul"
|
||||
# generate different executable for each passed datatype
|
||||
make benchmark_gemm_multi_d_[Datatype]_[Layout1] -j
|
||||
make benchmark_gemm_multi_d_[Datatype]_[Layout2] -j
|
||||
```
|
||||
`benchmark_gemm_multi_d_[Datatype]_[Layout]` will be located in the `./bin/` directory.
|
||||
|
||||
`benchmark_gemm_multi_d_[Datatype]_[Layout]` must be rebuilt everytime if configuration file is modified.
|
||||
|
||||
``` bash
|
||||
rm -rf tile_engine/ && make benchmark_gemm_multi_d_[Datatype]_[Layout] -j # rebuild
|
||||
```
|
||||
|
||||
## For eaxmple build for gfx942 for datatype with rcr layout
|
||||
``` bash
|
||||
mkdir build && cd build
|
||||
../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_MULTI_D_DATATYPE="fp16" -DGEMM_MULTI_D_LAYOUT="rcrr"
|
||||
make benchmark_gemm_multi_d_fp16_rcrr -j
|
||||
|
||||
## benchmark_gemm inputs
|
||||
```
|
||||
-m The value for m dimension. Default is 3840.
|
||||
-n The value for n dimension. Default is 4096.
|
||||
-k The value for k dimension. Default is 2048.
|
||||
-stride_a The stride value for tensor A. Default is 0.
|
||||
-stride_b The stride value for tensor B. Default is 0.
|
||||
-stride_ds The stride value for tensor Ds. Default is 0.
|
||||
-stride_e The stride value for tensor E. Default is 0.
|
||||
-split_k The split value for k dimension. Default is 1.
|
||||
-verify The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 for validation on GPU. Default is 1, validation on CPU, as validation on GPU is not supported.
|
||||
-log Wether output kernel instance information or not. Possible values are true or false. Default is false.
|
||||
-warmup The number of iterations before benchmark the kernel. Default is 50.
|
||||
-repeat The number of iterations to benchmark the kernel. Default is 100.
|
||||
-timer Whether if the timer is gpu timer or not. Possible values are false or true. Default is true.
|
||||
-init The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 for constant(1). Default is 0, random.
|
||||
-flush_cache To flush cache, possible values are true or false. Default is false.
|
||||
-rotating_count Number of iterations to rotate the cache. Default is 5.
|
||||
-metric Metric with which to measure kernel performance. Set to 0 for latency, 1 for tflops, or 2 for bandwidth. Default is 0, latency.
|
||||
-csv_filename The filename of benchmark result. Default is gemm_multi_d_kernel.
|
||||
-pipeline The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.
|
||||
-scheduler The type of scheduler. Possible values are intrawave. Default is intrawave.
|
||||
-epilogue The type of epilogue. Possible values are cshuffle or default. Default is cshuffle.
|
||||
-pad_m Whether pad or not in m direction. Possible values are true or false. Default is false.
|
||||
-pad_n Whether pad or not in n direction. Possible values are true or false. Default is false.
|
||||
-pad_k Whether pad or not in k direction. Possible values are true or false. Default is false.
|
||||
|
||||
Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in user_provided_config.json
|
||||
```
|
||||
Note: In `./configs/user_provided_config.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above.
|
||||
|
||||
## Example
|
||||
|
||||
The following JSON file specifies parameters used to generate and build GEMM kernels across all possible combinations of pipelines, schedulers, epilogues with different tile and warp sizes.
|
||||
|
||||
```json
|
||||
{
|
||||
/// other parameters ///
|
||||
|
||||
"tile_m": {
|
||||
"values": [256]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [256]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [64, 32]
|
||||
},
|
||||
|
||||
/// other parameters ///
|
||||
|
||||
"pipeline": {
|
||||
"values": ["compv3", "compv4", "mem"]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": ["intrawave", "interwave"]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": ["cshuffle"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
At runtime, a specific subset of the generated kernels can be selected using command-line arguments.
|
||||
``` bash
|
||||
./bin/benchmark_gemm_multi_d_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=cshuffle
|
||||
```
|
||||
The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and cshuffle epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings.
|
||||
@@ -1,73 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <functional>
|
||||
#include <tuple>
|
||||
#include <exception>
|
||||
|
||||
#include "benchmark_gemm_multi_d.hpp"
|
||||
#include "gemm_multi_d_profiler.hpp"
|
||||
|
||||
void benchmark_gemm_multi_d(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
GemmMultiDProblem gemm_multi_d_problem{arg_parser.get_int("split_k"),
|
||||
arg_parser.get_int("m"),
|
||||
arg_parser.get_int("n"),
|
||||
arg_parser.get_int("k"),
|
||||
arg_parser.get_int("stride_a"),
|
||||
arg_parser.get_int("stride_b"),
|
||||
arg_parser.get_int("stride_ds"),
|
||||
arg_parser.get_int("stride_ds"),
|
||||
arg_parser.get_int("stride_e"),
|
||||
DataTypeTraits<ADataType>::name,
|
||||
DataTypeTraits<BDataType>::name,
|
||||
DataTypeTraits<D0DataType>::name,
|
||||
DataTypeTraits<D1DataType>::name,
|
||||
DataTypeTraits<AccDataType>::name,
|
||||
DataTypeTraits<EDataType>::name,
|
||||
ALayout::name,
|
||||
BLayout::name,
|
||||
D0Layout::name,
|
||||
D1Layout::name,
|
||||
ELayout::name};
|
||||
|
||||
Setting setting{arg_parser.get_int("warmup"),
|
||||
arg_parser.get_int("repeat"),
|
||||
arg_parser.get_bool("timer"),
|
||||
arg_parser.get_int("verify"),
|
||||
arg_parser.get_int("init"),
|
||||
arg_parser.get_bool("log"),
|
||||
arg_parser.get_str("csv_filename"),
|
||||
arg_parser.get_bool("flush_cache"),
|
||||
arg_parser.get_int("rotating_count")};
|
||||
|
||||
auto& profiler = GemmMultiDProfiler::instance(setting);
|
||||
|
||||
try
|
||||
{
|
||||
auto kernel_func = get_kernel_func_by_trait(arg_parser);
|
||||
profiler.benchmark(gemm_multi_d_problem, kernel_func);
|
||||
profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Benchmark failed: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
auto [result, parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
benchmark_gemm_multi_d(parser);
|
||||
return 0;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << "\n";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
256 ]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,84 +1,104 @@
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
256
|
||||
]
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"tile_n": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"tile_k": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
4,
|
||||
16,
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16,
|
||||
32,
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
}
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"compv4",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle",
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false,
|
||||
true
|
||||
]
|
||||
}
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"compv4",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
"k_block_per_cu": 1
|
||||
}
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
256
|
||||
64
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
256
|
||||
192
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
@@ -42,24 +42,24 @@
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
8
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3"
|
||||
"compv4"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave"
|
||||
"intrawave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle"
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
@@ -76,6 +76,12 @@
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
true
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"k_block_per_cu": 1
|
||||
}
|
||||
@@ -7,80 +7,14 @@
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
#include <iomanip>
|
||||
|
||||
#include "gemm_multi_d_host_api.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_multi_d_common.hpp"
|
||||
|
||||
struct GemmMultiDProblem
|
||||
{
|
||||
int split_k_;
|
||||
int m_, n_, k_;
|
||||
int stride_a_, stride_b_, stride_d0_, stride_d1_, stride_e_;
|
||||
|
||||
std::string dtype_a_, dtype_b_, dtype_d0_, dtype_d1_, dtype_acc_, dtype_e_;
|
||||
std::string layout_a_, layout_b_, layout_d0_, layout_d1_, layout_e_;
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const GemmMultiDProblem& problem)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"split_k\":" << problem.split_k_ << ",\n"
|
||||
<< " \"m\":" << problem.m_ << ",\n"
|
||||
<< " \"n\":" << problem.n_ << ",\n"
|
||||
<< " \"k\":" << problem.k_ << ",\n"
|
||||
<< " \"stride_a\":" << problem.stride_a_ << ",\n"
|
||||
<< " \"stride_b\":" << problem.stride_b_ << ",\n"
|
||||
<< " \"stride_d0\":" << problem.stride_d0_ << ",\n"
|
||||
<< " \"stride_d1\":" << problem.stride_d1_ << ",\n"
|
||||
<< " \"stride_e\":" << problem.stride_e_ << ",\n"
|
||||
<< " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n"
|
||||
<< " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n"
|
||||
<< " \"dtype_d0\":\"" << problem.dtype_d0_ << "\",\n"
|
||||
<< " \"dtype_d1\":\"" << problem.dtype_d1_ << "\",\n"
|
||||
<< " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n"
|
||||
<< " \"dtype_e\":\"" << problem.dtype_e_ << "\",\n"
|
||||
<< " \"layout_a\":\"" << problem.layout_a_ << "\",\n"
|
||||
<< " \"layout_b\":\"" << problem.layout_b_ << "\",\n"
|
||||
<< " \"layout_d0\":\"" << problem.layout_d0_ << "\",\n"
|
||||
<< " \"layout_d1\":\"" << problem.layout_d1_ << "\",\n"
|
||||
<< " \"layout_e\":\"" << problem.layout_e_ << "\"\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct Setting
|
||||
{
|
||||
int n_warmup_;
|
||||
int n_repeat_;
|
||||
bool is_gpu_timer_;
|
||||
int verify_;
|
||||
int init_method_;
|
||||
bool log_;
|
||||
std::string csv_filename_;
|
||||
bool flush_cache_;
|
||||
int rotating_count_;
|
||||
};
|
||||
|
||||
// @brief Function to get the kernel output with reference implementation on CPU
|
||||
void gemm_multi_d_host_reference(int verify,
|
||||
ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
ck_tile::HostTensor<BDataType>& b_k_n,
|
||||
ck_tile::HostTensor<D0DataType>& d0_m_n,
|
||||
ck_tile::HostTensor<D1DataType>& d1_m_n,
|
||||
ck_tile::HostTensor<EDataType>& e_m_n_host_result)
|
||||
{
|
||||
if(verify > 0)
|
||||
{
|
||||
// Currently supporting on CPU verification for Gemm Multi D
|
||||
// e_m_n_host_result.SetZero();
|
||||
ck_tile::reference_gemm_multiple_d<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ElementWiseFn>(
|
||||
a_m_k, b_k_n, {d0_m_n, d1_m_n}, e_m_n_host_result);
|
||||
}
|
||||
}
|
||||
// Data types and Layouts are defined by the generated kernel headers
|
||||
// No hardcoded type definitions here to avoid conflicts
|
||||
|
||||
enum class Metric
|
||||
{
|
||||
@@ -100,6 +34,43 @@ inline constexpr auto get_metric_name(Metric m)
|
||||
}
|
||||
}
|
||||
|
||||
struct GemmMultiDProblem
|
||||
{
|
||||
int split_k_;
|
||||
int m_, n_, k_;
|
||||
int stride_a_, stride_b_, stride_d0_, stride_d1_, stride_c_;
|
||||
|
||||
std::string dtype_a_, dtype_b_, dtype_d0_, dtype_d1_, dtype_acc_, dtype_c_;
|
||||
std::string layout_a_, layout_b_, layout_d0_, layout_d1_, layout_c_;
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const GemmMultiDProblem& problem)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"split_k\":" << problem.split_k_ << ",\n"
|
||||
<< " \"m\":" << problem.m_ << ",\n"
|
||||
<< " \"n\":" << problem.n_ << ",\n"
|
||||
<< " \"k\":" << problem.k_ << ",\n"
|
||||
<< " \"stride_a\":" << problem.stride_a_ << ",\n"
|
||||
<< " \"stride_b\":" << problem.stride_b_ << ",\n"
|
||||
<< " \"stride_d0\":" << problem.stride_d0_ << ",\n"
|
||||
<< " \"stride_d1\":" << problem.stride_d1_ << ",\n"
|
||||
<< " \"stride_c\":" << problem.stride_c_ << ",\n"
|
||||
<< " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n"
|
||||
<< " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n"
|
||||
<< " \"dtype_d0\":\"" << problem.dtype_d0_ << "\",\n"
|
||||
<< " \"dtype_d1\":\"" << problem.dtype_d1_ << "\",\n"
|
||||
<< " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n"
|
||||
<< " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n"
|
||||
<< " \"layout_a\":\"" << problem.layout_a_ << "\",\n"
|
||||
<< " \"layout_b\":\"" << problem.layout_b_ << "\",\n"
|
||||
<< " \"layout_d0\":\"" << problem.layout_d0_ << "\",\n"
|
||||
<< " \"layout_d1\":\"" << problem.layout_d1_ << "\",\n"
|
||||
<< " \"layout_c\":\"" << problem.layout_c_ << "\"" << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct PerformanceResult
|
||||
{
|
||||
double latency_;
|
||||
@@ -143,15 +114,28 @@ struct KernelInstance
|
||||
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"name\": \"" << "{\n"
|
||||
<< obj.name_ << "\n}" << "\",\n"
|
||||
<< " \"problem\": \"" << obj.problem_ << "\",\n"
|
||||
<< " \"name\": \"" << obj.name_ << "\",\n"
|
||||
<< " \"problem\": " << obj.problem_ << ",\n"
|
||||
<< " \"perf_result\": " << obj.perf_result_ << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct Setting
|
||||
{
|
||||
int n_warmup_;
|
||||
int n_repeat_;
|
||||
bool is_gpu_timer_;
|
||||
int verify_;
|
||||
int init_method_;
|
||||
bool log_;
|
||||
std::string csv_filename_;
|
||||
bool flush_cache_;
|
||||
int rotating_count_;
|
||||
bool json_output_;
|
||||
};
|
||||
|
||||
inline std::string get_rocm_version()
|
||||
{
|
||||
std::ifstream version_file("/opt/rocm/.info/version");
|
||||
@@ -164,6 +148,11 @@ inline std::string get_rocm_version()
|
||||
return "Unknown";
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename D0DataType,
|
||||
typename AccDataType,
|
||||
typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
@@ -175,17 +164,17 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
|
||||
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
|
||||
// Use higher threshold
|
||||
@@ -195,16 +184,19 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
/// @brief Function to compare the results of the device and host computations
|
||||
bool compare(std::string instanceName,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::HostTensor<EDataType>& e_m_n_dev_result,
|
||||
ck_tile::HostTensor<EDataType>& e_m_n_host_result)
|
||||
ck_tile::index_t kbatch,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(e_m_n_host_result.mData.begin(), e_m_n_host_result.mData.end());
|
||||
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
|
||||
|
||||
const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value);
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<ADataType, BDataType, D0DataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
|
||||
bool pass = ck_tile::check_err(e_m_n_dev_result,
|
||||
e_m_n_host_result,
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
@@ -216,3 +208,25 @@ bool compare(std::string instanceName,
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
/// @brief Function to get the kernel output with reference implementation on CPU/GPU
|
||||
void gemm_multi_d_host_reference(int verify,
|
||||
ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
ck_tile::HostTensor<BDataType>& b_k_n,
|
||||
ck_tile::HostTensor<D0DataType>& d0_m_n,
|
||||
ck_tile::HostTensor<D1DataType>& d1_m_n,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
|
||||
{
|
||||
if(verify > 0)
|
||||
{
|
||||
// Currently supporting on CPU verification for Gemm Multi D
|
||||
// e_m_n_host_result.SetZero();
|
||||
ck_tile::reference_gemm_multiple_d<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ElementWiseFn>(
|
||||
a_m_k, b_k_n, {d0_m_n, d1_m_n}, c_m_n_host_result);
|
||||
}
|
||||
}
|
||||
683
tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py
Executable file
683
tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py
Executable file
@@ -0,0 +1,683 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
import sys
|
||||
import json
|
||||
import subprocess
|
||||
import argparse
|
||||
import csv
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
|
||||
|
||||
class GemmMultiDBenchmark:
|
||||
def __init__(self, build_dir: str, verbose: bool = False):
|
||||
self.build_dir = Path(build_dir)
|
||||
self.verbose = verbose
|
||||
self.results = []
|
||||
|
||||
def discover_kernels(self) -> List[Path]:
|
||||
"""Find all benchmark_gemm_multi_d_* executables in the build directory"""
|
||||
bin_dir = self.build_dir / "bin"
|
||||
if not bin_dir.exists():
|
||||
print(f"Error: Binary directory {bin_dir} does not exist")
|
||||
return []
|
||||
|
||||
kernels = list(bin_dir.glob("benchmark_gemm_multi_d_*"))
|
||||
if self.verbose:
|
||||
print(f"Found {len(kernels)} kernel executables")
|
||||
for k in kernels:
|
||||
print(f" - {k.name}")
|
||||
return kernels
|
||||
|
||||
def extract_kernel_info(self, kernel_path: Path) -> Dict[str, str]:
|
||||
"""Extract comprehensive kernel information from filename"""
|
||||
name = kernel_path.stem
|
||||
|
||||
# Initialize with basic info
|
||||
info = {
|
||||
"executable": str(kernel_path),
|
||||
"name": name,
|
||||
"data_type": "unknown",
|
||||
"layout": "unknown",
|
||||
"pipeline": "unknown",
|
||||
"scheduler": "unknown",
|
||||
"epilogue": "unknown",
|
||||
}
|
||||
|
||||
# Parse the kernel name pattern:
|
||||
# benchmark_gemm_multi_d_fp16_rcr_mem_default_intrawave_False_False_False_False_False_256x256x32_2x2x1_4x64x16
|
||||
parts = name.split("_")
|
||||
|
||||
if len(parts) >= 5:
|
||||
# Extract data type (3rd part after benchmark_gemm_)
|
||||
info["data_type"] = parts[4] if len(parts) > 4 else "unknown"
|
||||
|
||||
# Extract layout (4th part)
|
||||
info["layout"] = parts[5] if len(parts) > 5 else "unknown"
|
||||
|
||||
# Extract pipeline (5th part)
|
||||
info["pipeline"] = parts[6] if len(parts) > 6 else "unknown"
|
||||
|
||||
# Extract epilogue (6th part)
|
||||
info["epilogue"] = parts[7] if len(parts) > 7 else "unknown"
|
||||
|
||||
# Extract scheduler (7th part)
|
||||
info["scheduler"] = parts[8] if len(parts) > 8 else "unknown"
|
||||
|
||||
# Extract detailed configuration from the end of the name
|
||||
config_info = self.parse_detailed_config(name)
|
||||
info.update(config_info)
|
||||
|
||||
# Generate config ID
|
||||
info["config_id"] = self.generate_config_id(info)
|
||||
|
||||
return info
|
||||
|
||||
def parse_detailed_config(self, kernel_name: str) -> Dict:
|
||||
"""Parse detailed configuration from kernel name"""
|
||||
config = {
|
||||
"tile_sizes": {"tile_m": 0, "tile_n": 0, "tile_k": 0},
|
||||
"warp_config": {"warp_m": 0, "warp_n": 0, "warp_k": 0},
|
||||
"warp_tile": {"warp_tile_m": 0, "warp_tile_n": 0, "warp_tile_k": 0},
|
||||
"optimization_flags": {
|
||||
"pad_m": False,
|
||||
"pad_n": False,
|
||||
"pad_k": False,
|
||||
"persistent": False,
|
||||
},
|
||||
}
|
||||
|
||||
# Split by underscore and look for patterns
|
||||
parts = kernel_name.split("_")
|
||||
|
||||
# Look for boolean flags (sequence of True/False values)
|
||||
bool_sequence = []
|
||||
for i, part in enumerate(parts):
|
||||
if part in ["True", "False"]:
|
||||
bool_sequence.append(part == "True")
|
||||
# Continue collecting consecutive boolean values
|
||||
j = i + 1
|
||||
while j < len(parts) and parts[j] in ["True", "False"]:
|
||||
bool_sequence.append(parts[j] == "True")
|
||||
j += 1
|
||||
break
|
||||
|
||||
# Assign boolean flags if we found them
|
||||
# Order: pad_m, pad_n, pad_k, persistent (4 flags total)
|
||||
if len(bool_sequence) >= 4:
|
||||
config["optimization_flags"]["pad_m"] = bool_sequence[0]
|
||||
config["optimization_flags"]["pad_n"] = bool_sequence[1]
|
||||
config["optimization_flags"]["pad_k"] = bool_sequence[2]
|
||||
config["optimization_flags"]["persistent"] = bool_sequence[3]
|
||||
|
||||
# Look for tile size patterns (e.g., 256x256x32_2x2x1_4x64x16)
|
||||
# The pattern is: tile_sizes_warp_config_warp_tile
|
||||
dimension_groups = []
|
||||
for part in parts:
|
||||
if "x" in part and len(part.split("x")) == 3:
|
||||
try:
|
||||
dims = [int(x) for x in part.split("x")]
|
||||
if all(d > 0 for d in dims):
|
||||
dimension_groups.append(dims)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Assign dimensions based on order and magnitude
|
||||
if len(dimension_groups) >= 3:
|
||||
# Sort by magnitude to identify: largest=tile_sizes, smallest=warp_config, middle=warp_tile
|
||||
sorted_groups = sorted(dimension_groups, key=max, reverse=True)
|
||||
|
||||
# Largest dimensions = tile sizes
|
||||
config["tile_sizes"]["tile_m"] = sorted_groups[0][0]
|
||||
config["tile_sizes"]["tile_n"] = sorted_groups[0][1]
|
||||
config["tile_sizes"]["tile_k"] = sorted_groups[0][2]
|
||||
|
||||
# Smallest dimensions = warp config
|
||||
config["warp_config"]["warp_m"] = sorted_groups[2][0]
|
||||
config["warp_config"]["warp_n"] = sorted_groups[2][1]
|
||||
config["warp_config"]["warp_k"] = sorted_groups[2][2]
|
||||
|
||||
# Middle dimensions = warp tile
|
||||
config["warp_tile"]["warp_tile_m"] = sorted_groups[1][0]
|
||||
config["warp_tile"]["warp_tile_n"] = sorted_groups[1][1]
|
||||
config["warp_tile"]["warp_tile_k"] = sorted_groups[1][2]
|
||||
elif len(dimension_groups) == 2:
|
||||
# If only 2 groups, assign based on magnitude
|
||||
sorted_groups = sorted(dimension_groups, key=max, reverse=True)
|
||||
|
||||
# Larger = tile sizes
|
||||
config["tile_sizes"]["tile_m"] = sorted_groups[0][0]
|
||||
config["tile_sizes"]["tile_n"] = sorted_groups[0][1]
|
||||
config["tile_sizes"]["tile_k"] = sorted_groups[0][2]
|
||||
|
||||
# Smaller = warp config
|
||||
config["warp_config"]["warp_m"] = sorted_groups[1][0]
|
||||
config["warp_config"]["warp_n"] = sorted_groups[1][1]
|
||||
config["warp_config"]["warp_k"] = sorted_groups[1][2]
|
||||
elif len(dimension_groups) == 1:
|
||||
# Only one group - assume it's tile sizes
|
||||
config["tile_sizes"]["tile_m"] = dimension_groups[0][0]
|
||||
config["tile_sizes"]["tile_n"] = dimension_groups[0][1]
|
||||
config["tile_sizes"]["tile_k"] = dimension_groups[0][2]
|
||||
|
||||
return config
|
||||
|
||||
def generate_config_id(self, info: Dict) -> str:
|
||||
"""Generate a compact config ID from kernel info"""
|
||||
# Create a compact identifier
|
||||
parts = [
|
||||
info.get("data_type", "unk"),
|
||||
info.get("layout", "unk"),
|
||||
info.get("pipeline", "unk"),
|
||||
info.get("scheduler", "unk"),
|
||||
]
|
||||
|
||||
# Add tile configuration if available
|
||||
tile_sizes = info.get("tile_sizes", {})
|
||||
if tile_sizes.get("tile_m", 0) > 0:
|
||||
tile_str = (
|
||||
f"{tile_sizes['tile_m']}x{tile_sizes['tile_n']}x{tile_sizes['tile_k']}"
|
||||
)
|
||||
parts.append(tile_str)
|
||||
|
||||
# Add warp config if available
|
||||
warp_config = info.get("warp_config", {})
|
||||
if warp_config.get("warp_m", 0) > 0:
|
||||
warp_str = f"w{warp_config['warp_m']}x{warp_config['warp_n']}x{warp_config['warp_k']}"
|
||||
parts.append(warp_str)
|
||||
|
||||
# Add warp tile if available
|
||||
warp_tile = info.get("warp_tile", {})
|
||||
if warp_tile.get("warp_tile_m", 0) > 0:
|
||||
warp_tile_str = f"wt{warp_tile['warp_tile_m']}x{warp_tile['warp_tile_n']}x{warp_tile['warp_tile_k']}"
|
||||
parts.append(warp_tile_str)
|
||||
|
||||
return "_".join(parts)
|
||||
|
||||
def run_kernel(self, kernel_path: Path, params: Dict[str, str]) -> Optional[Dict]:
|
||||
"""Run a single kernel with given parameters and save output to individual JSON file"""
|
||||
# Create results directory
|
||||
results_dir = self.build_dir / "results"
|
||||
results_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Generate unique JSON filename for this kernel
|
||||
json_file = results_dir / f"{kernel_path.stem}.json"
|
||||
|
||||
cmd = [str(kernel_path)]
|
||||
|
||||
# Add parameters
|
||||
for key, value in params.items():
|
||||
cmd.append(f"-{key}={value}")
|
||||
|
||||
# Add JSON output flag for clean JSON output
|
||||
cmd.append("-json_output=true")
|
||||
|
||||
if self.verbose:
|
||||
print(f"Running: {' '.join(cmd)}")
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
|
||||
|
||||
if result.returncode != 0:
|
||||
print(f"Error running {kernel_path.name}: {result.stderr}")
|
||||
return None
|
||||
|
||||
# Save raw output to individual JSON file
|
||||
output = result.stdout.strip()
|
||||
if output:
|
||||
with open(json_file, "w") as f:
|
||||
f.write(output)
|
||||
|
||||
# Parse the JSON file
|
||||
return self.parse_json_file(json_file)
|
||||
else:
|
||||
print(f"No output from {kernel_path.name}")
|
||||
return None
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
print(f"Timeout running {kernel_path.name}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error running {kernel_path.name}: {e}")
|
||||
return None
|
||||
|
||||
def parse_json_file(self, json_file: Path) -> Optional[Dict]:
|
||||
"""Parse JSON data from individual kernel output file"""
|
||||
try:
|
||||
with open(json_file, "r") as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# Parse the JSON directly since executables produce clean JSON
|
||||
data = json.loads(content)
|
||||
|
||||
# Return the complete JSON data as-is, just add some convenience fields
|
||||
result = data.copy()
|
||||
if "perf_result" in data:
|
||||
perf = data["perf_result"]
|
||||
# Add convenience fields for backward compatibility
|
||||
result["time_ms"] = perf.get("latency(ms)", 0)
|
||||
result["tflops"] = perf.get("tflops(TFlops)", 0)
|
||||
result["bandwidth_gb_s"] = perf.get("bandwidth(GB/s)", 0)
|
||||
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
if self.verbose:
|
||||
print(f"Failed to parse JSON from {json_file}: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
print(f"Error reading JSON file {json_file}: {e}")
|
||||
return None
|
||||
|
||||
def benchmark_problem_size(
|
||||
self,
|
||||
kernels: List[Path],
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
split_k: int = 1,
|
||||
verify: int = 0,
|
||||
warmup: int = 50,
|
||||
repeat: int = 100,
|
||||
flush_cache: bool = True,
|
||||
rotating_count: int = 1000,
|
||||
) -> List[Dict]:
|
||||
"""Benchmark all kernels for a specific problem size"""
|
||||
results = []
|
||||
|
||||
params = {
|
||||
"m": m,
|
||||
"n": n,
|
||||
"k": k,
|
||||
"split_k": split_k,
|
||||
"verify": verify,
|
||||
"warmup": warmup,
|
||||
"repeat": repeat,
|
||||
"flush_cache": str(flush_cache).lower(),
|
||||
"rotating_count": rotating_count,
|
||||
}
|
||||
|
||||
print(f"\nBenchmarking M={m}, N={n}, K={k}, split_k={split_k}")
|
||||
|
||||
for kernel_path in kernels:
|
||||
kernel_info = self.extract_kernel_info(kernel_path)
|
||||
result = self.run_kernel(kernel_path, params)
|
||||
|
||||
if result:
|
||||
# Create new structured result format
|
||||
structured_result = {
|
||||
"name": kernel_info["name"], # Add name field for compatibility
|
||||
"config_id": kernel_info["config_id"],
|
||||
"problem": result.get("problem", {}),
|
||||
"perf_result": result.get("perf_result", {}),
|
||||
"config": {
|
||||
"data_type": kernel_info["data_type"],
|
||||
"layout": kernel_info["layout"],
|
||||
"pipeline": kernel_info["pipeline"],
|
||||
"scheduler": kernel_info["scheduler"],
|
||||
"epilogue": kernel_info["epilogue"],
|
||||
"tile_sizes": kernel_info.get("tile_sizes", {}),
|
||||
"warp_config": kernel_info.get("warp_config", {}),
|
||||
"warp_tile": kernel_info.get("warp_tile", {}),
|
||||
"optimization_flags": kernel_info.get("optimization_flags", {}),
|
||||
},
|
||||
"executable": kernel_info["executable"],
|
||||
# Keep backward compatibility fields
|
||||
"time_ms": result.get("time_ms", 0),
|
||||
"tflops": result.get("tflops", 0),
|
||||
"bandwidth_gb_s": result.get("bandwidth_gb_s", 0),
|
||||
}
|
||||
|
||||
results.append(structured_result)
|
||||
|
||||
if self.verbose:
|
||||
print(
|
||||
f" {kernel_info['config_id']}: {structured_result['tflops']:.2f} TFLOPS, {structured_result['bandwidth_gb_s']:.2f} GB/s, {structured_result['time_ms']:.2f}ms"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def find_best_kernel(
|
||||
self, results: List[Dict], metric: str = "tflops"
|
||||
) -> Optional[Dict]:
|
||||
"""Find the best performing kernel based on metric"""
|
||||
if not results:
|
||||
return None
|
||||
|
||||
if metric == "tflops":
|
||||
return max(results, key=lambda x: x.get("tflops", 0))
|
||||
elif metric == "time_ms":
|
||||
return min(results, key=lambda x: x.get("time_ms", float("inf")))
|
||||
elif metric == "bandwidth_gb_s":
|
||||
return max(results, key=lambda x: x.get("bandwidth_gb_s", 0))
|
||||
else:
|
||||
raise ValueError(f"Unknown metric: {metric}")
|
||||
|
||||
def benchmark_sweep(
|
||||
self,
|
||||
problem_sizes: List[Tuple[int, int, int]],
|
||||
split_k_values: List[int] = [1],
|
||||
verify: bool = False,
|
||||
warmup: int = 50,
|
||||
repeat: int = 100,
|
||||
flush_cache: bool = True,
|
||||
rotating_count: int = 1000,
|
||||
) -> Dict:
|
||||
"""Run comprehensive benchmark sweep"""
|
||||
kernels = self.discover_kernels()
|
||||
if not kernels:
|
||||
print("No kernels found!")
|
||||
return {}
|
||||
|
||||
all_results = []
|
||||
best_kernels = {}
|
||||
|
||||
for m, n, k in problem_sizes:
|
||||
for split_k in split_k_values:
|
||||
results = self.benchmark_problem_size(
|
||||
kernels,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
split_k,
|
||||
verify=2 if verify else 0,
|
||||
warmup=warmup,
|
||||
repeat=repeat,
|
||||
flush_cache=flush_cache,
|
||||
rotating_count=rotating_count,
|
||||
)
|
||||
|
||||
all_results.extend(results)
|
||||
|
||||
# Find best kernel for this configuration
|
||||
best = self.find_best_kernel(results)
|
||||
if best:
|
||||
key = f"m{m}_n{n}_k{k}_splitk{split_k}"
|
||||
best_kernels[key] = best
|
||||
print(
|
||||
f"Best for {key}: {best['name']} ({best['tflops']:.2f} TFLOPS, {best['bandwidth_gb_s']:.2f} GB/s, {best['time_ms']:.2f}ms)"
|
||||
)
|
||||
|
||||
self.results = all_results
|
||||
return best_kernels
|
||||
|
||||
def export_csv(self, filename: str):
|
||||
"""Export all results to CSV"""
|
||||
if not self.results:
|
||||
print("No results to export")
|
||||
return
|
||||
|
||||
# Get all unique keys from results
|
||||
all_keys = set()
|
||||
for result in self.results:
|
||||
all_keys.update(result.keys())
|
||||
|
||||
# Sort keys for consistent output
|
||||
fieldnames = sorted(all_keys)
|
||||
|
||||
with open(filename, "w", newline="") as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(self.results)
|
||||
|
||||
print(f"Results exported to {filename}")
|
||||
|
||||
def export_best_kernels(self, best_kernels: Dict, filename: str):
|
||||
"""Export best kernel selections to file"""
|
||||
with open(filename, "w") as f:
|
||||
f.write("# Best kernel selections\n")
|
||||
f.write(
|
||||
"# Format: problem_size -> kernel_name (TFLOPS, bandwidth, latency)\n\n"
|
||||
)
|
||||
|
||||
for key, kernel in sorted(best_kernels.items()):
|
||||
f.write(
|
||||
f"{key}: {kernel['name']} ({kernel['tflops']:.2f} TFLOPS, {kernel['bandwidth_gb_s']:.2f} GB/s, {kernel['time_ms']:.2f}ms)\n"
|
||||
)
|
||||
|
||||
print(f"Best kernels exported to {filename}")
|
||||
|
||||
def export_json(self, filename: str, best_kernels: Dict = None):
|
||||
"""Export all results and best kernels to JSON with comprehensive metadata"""
|
||||
from datetime import datetime
|
||||
|
||||
# Calculate comprehensive summary statistics for all metrics
|
||||
successful_results = [r for r in self.results if r.get("tflops", 0) > 0]
|
||||
|
||||
tflops_values = [r.get("tflops", 0) for r in successful_results]
|
||||
bandwidth_values = [r.get("bandwidth_gb_s", 0) for r in successful_results]
|
||||
latency_values = [
|
||||
r.get("time_ms", 0) for r in successful_results if r.get("time_ms", 0) > 0
|
||||
]
|
||||
|
||||
# Performance breakdown by kernel type
|
||||
pipeline_stats = {}
|
||||
scheduler_stats = {}
|
||||
data_type_stats = {}
|
||||
|
||||
for result in successful_results:
|
||||
# Get config info from the new structure
|
||||
config = result.get("config", {})
|
||||
|
||||
# Pipeline statistics
|
||||
pipeline = config.get("pipeline", "unknown")
|
||||
if pipeline not in pipeline_stats:
|
||||
pipeline_stats[pipeline] = {
|
||||
"count": 0,
|
||||
"avg_tflops": 0,
|
||||
"best_tflops": 0,
|
||||
}
|
||||
pipeline_stats[pipeline]["count"] += 1
|
||||
pipeline_stats[pipeline]["best_tflops"] = max(
|
||||
pipeline_stats[pipeline]["best_tflops"], result.get("tflops", 0)
|
||||
)
|
||||
|
||||
# Scheduler statistics
|
||||
scheduler = config.get("scheduler", "unknown")
|
||||
if scheduler not in scheduler_stats:
|
||||
scheduler_stats[scheduler] = {
|
||||
"count": 0,
|
||||
"avg_tflops": 0,
|
||||
"best_tflops": 0,
|
||||
}
|
||||
scheduler_stats[scheduler]["count"] += 1
|
||||
scheduler_stats[scheduler]["best_tflops"] = max(
|
||||
scheduler_stats[scheduler]["best_tflops"], result.get("tflops", 0)
|
||||
)
|
||||
|
||||
# Data type statistics
|
||||
data_type = config.get("data_type", "unknown")
|
||||
if data_type not in data_type_stats:
|
||||
data_type_stats[data_type] = {
|
||||
"count": 0,
|
||||
"avg_tflops": 0,
|
||||
"best_tflops": 0,
|
||||
}
|
||||
data_type_stats[data_type]["count"] += 1
|
||||
data_type_stats[data_type]["best_tflops"] = max(
|
||||
data_type_stats[data_type]["best_tflops"], result.get("tflops", 0)
|
||||
)
|
||||
|
||||
# Calculate averages for breakdown stats
|
||||
for stats_dict, field_name in [
|
||||
(pipeline_stats, "pipeline"),
|
||||
(scheduler_stats, "scheduler"),
|
||||
(data_type_stats, "data_type"),
|
||||
]:
|
||||
for key in stats_dict:
|
||||
relevant_results = [
|
||||
r
|
||||
for r in successful_results
|
||||
if r.get("config", {}).get(field_name, "unknown") == key
|
||||
]
|
||||
if relevant_results:
|
||||
stats_dict[key]["avg_tflops"] = sum(
|
||||
r.get("tflops", 0) for r in relevant_results
|
||||
) / len(relevant_results)
|
||||
|
||||
output_data = {
|
||||
"benchmark_metadata": {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"total_kernels_tested": len(self.results),
|
||||
"unique_kernels": len(
|
||||
set(r.get("name", "unknown") for r in self.results)
|
||||
),
|
||||
"successful_runs": len(successful_results),
|
||||
"failed_runs": len(self.results) - len(successful_results),
|
||||
},
|
||||
"performance_summary": {
|
||||
"tflops_stats": {
|
||||
"best": max(tflops_values, default=0),
|
||||
"average": sum(tflops_values) / len(tflops_values)
|
||||
if tflops_values
|
||||
else 0,
|
||||
"min": min(tflops_values, default=0),
|
||||
"median": sorted(tflops_values)[len(tflops_values) // 2]
|
||||
if tflops_values
|
||||
else 0,
|
||||
},
|
||||
"bandwidth_stats": {
|
||||
"best_gb_s": max(bandwidth_values, default=0),
|
||||
"average_gb_s": sum(bandwidth_values) / len(bandwidth_values)
|
||||
if bandwidth_values
|
||||
else 0,
|
||||
"min_gb_s": min(bandwidth_values, default=0),
|
||||
"median_gb_s": sorted(bandwidth_values)[len(bandwidth_values) // 2]
|
||||
if bandwidth_values
|
||||
else 0,
|
||||
},
|
||||
"latency_stats": {
|
||||
"best_ms": min(latency_values, default=0),
|
||||
"average_ms": sum(latency_values) / len(latency_values)
|
||||
if latency_values
|
||||
else 0,
|
||||
"max_ms": max(latency_values, default=0),
|
||||
"median_ms": sorted(latency_values)[len(latency_values) // 2]
|
||||
if latency_values
|
||||
else 0,
|
||||
},
|
||||
"kernel_type_breakdown": {
|
||||
"by_pipeline": pipeline_stats,
|
||||
"by_scheduler": scheduler_stats,
|
||||
"by_data_type": data_type_stats,
|
||||
},
|
||||
"total_problem_configurations": len(best_kernels)
|
||||
if best_kernels
|
||||
else 0,
|
||||
},
|
||||
"kernel_results": self.results,
|
||||
"best_kernels_by_problem": best_kernels or {},
|
||||
}
|
||||
|
||||
with open(filename, "w") as f:
|
||||
json.dump(output_data, f, indent=2)
|
||||
|
||||
print(f"JSON results exported to {filename}")
|
||||
print(f" - Total kernels: {len(self.results)}")
|
||||
print(f" - Successful runs: {len(successful_results)}")
|
||||
print(f" - Best TFLOPS: {max(tflops_values, default=0):.2f}")
|
||||
print(f" - Best bandwidth: {max(bandwidth_values, default=0):.2f} GB/s")
|
||||
print(f" - Best latency: {min(latency_values, default=0):.2f}ms")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="GEMM Multi D Kernel Benchmarking Tool"
|
||||
)
|
||||
parser.add_argument(
|
||||
"build_dir", help="Build directory containing kernel executables"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--problem-sizes",
|
||||
nargs="+",
|
||||
default=["1024,1024,1024", "2048,2048,2048", "4096,4096,4096"],
|
||||
help="Problem sizes as M,N,K tuples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split-k", nargs="+", type=int, default=[1], help="Split-K values to test"
|
||||
)
|
||||
parser.add_argument("--verify", action="store_true", help="Enable verification")
|
||||
parser.add_argument(
|
||||
"--csv",
|
||||
default="gemm_multi_d_benchmark_results.csv",
|
||||
help="CSV output filename",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--best", default="best_kernels.txt", help="Best kernels output filename"
|
||||
)
|
||||
parser.add_argument("--verbose", action="store_true", help="Verbose output")
|
||||
parser.add_argument(
|
||||
"--warmup",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of warmup iterations (default: 50)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repeat",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of benchmark iterations (default: 100)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flush-cache",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Enable cache flushing (default: True)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rotating-count",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of iterations to rotate cache (default: 1000)",
|
||||
)
|
||||
parser.add_argument("--json", help="JSON output filename (optional)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse problem sizes
|
||||
problem_sizes = []
|
||||
for size_str in args.problem_sizes:
|
||||
try:
|
||||
m, n, k = map(int, size_str.split(","))
|
||||
problem_sizes.append((m, n, k))
|
||||
except ValueError:
|
||||
print(f"Invalid problem size: {size_str}")
|
||||
return 1
|
||||
|
||||
# Create benchmark instance
|
||||
benchmark = GemmMultiDBenchmark(args.build_dir, verbose=args.verbose)
|
||||
|
||||
# Run benchmark sweep
|
||||
print("Starting GEMM Multi D kernel benchmark sweep...")
|
||||
start_time = time.time()
|
||||
|
||||
best_kernels = benchmark.benchmark_sweep(
|
||||
problem_sizes=problem_sizes,
|
||||
split_k_values=args.split_k,
|
||||
verify=args.verify,
|
||||
warmup=args.warmup,
|
||||
repeat=args.repeat,
|
||||
flush_cache=args.flush_cache,
|
||||
rotating_count=args.rotating_count,
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"\nBenchmark completed in {elapsed_time:.2f} seconds")
|
||||
|
||||
# Export results
|
||||
benchmark.export_csv(args.csv)
|
||||
benchmark.export_best_kernels(best_kernels, args.best)
|
||||
|
||||
# Export JSON if requested
|
||||
if args.json:
|
||||
benchmark.export_json(args.json, best_kernels)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
170
tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp
Normal file
170
tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp
Normal file
@@ -0,0 +1,170 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
#include <tuple>
|
||||
#include <exception>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_multi_d_profiler.hpp"
|
||||
#include "gemm_multi_d_common.hpp"
|
||||
|
||||
// The kernel header is included via the compile command line with -include flag
|
||||
// It defines SelectedKernel struct and KERNEL_NAME
|
||||
// DataTypeTraits are now defined in gemm_multi_d_common.hpp
|
||||
|
||||
// Create argument parser
|
||||
inline auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.")
|
||||
.insert("n", "4096", "The value for n dimension. Default is 4096.")
|
||||
.insert("k", "2048", "The value for k dimension. Default is 2048.")
|
||||
.insert("stride_a", "0", "The stride value for tensor A. Default is 0.")
|
||||
.insert("stride_b", "0", "The stride value for tensor B. Default is 0.")
|
||||
.insert("stride_ds", "0", "The stride value for tensor Ds . Default is 0.")
|
||||
.insert("stride_c", "0", "The stride value for tensor C. Default is 0.")
|
||||
.insert("split_k", "1", "The split value for k dimension. Default is 1.")
|
||||
.insert("verify",
|
||||
"1",
|
||||
"for validation on GPU. Default is 1, validation on CPU, as validation on GPU is "
|
||||
"not supported.")
|
||||
.insert("log",
|
||||
"false",
|
||||
"Whether output kernel instance information or not. Possible values are true or "
|
||||
"false. Default is false")
|
||||
.insert(
|
||||
"warmup", "50", "The number of iterations before benchmark the kernel. Default is 50.")
|
||||
.insert(
|
||||
"repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.")
|
||||
.insert("timer",
|
||||
"true",
|
||||
"Whether if the timer is gpu timer or not. Possible values are false or true. "
|
||||
"Default is true.")
|
||||
.insert("init",
|
||||
"0",
|
||||
"The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
|
||||
"for constant(1). Default is 0, random.")
|
||||
.insert("flush_cache",
|
||||
"true",
|
||||
"To flush cache, possible values are true or false. "
|
||||
"Default is false.")
|
||||
.insert("rotating_count", "1000", "number of iterations to rotate the cache. default is 5.")
|
||||
.insert("metric",
|
||||
"0",
|
||||
"Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
|
||||
"tflops, or 2 for bandwidth. Default is 0, latency.")
|
||||
.insert("csv_filename",
|
||||
"",
|
||||
"The filename of benchmark result. Default is empty (no CSV output).")
|
||||
.insert("structured_sparsity",
|
||||
"false",
|
||||
"Whether use sparsity kernel or not. Possible values are true or false. Default is "
|
||||
"false")
|
||||
.insert("json_output",
|
||||
"false",
|
||||
"Whether to output results in JSON format only. Possible values are true or false. "
|
||||
"Default is "
|
||||
"false");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
void benchmark_single(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
// Use DataTypeTraits to get the actual type names from the generated header
|
||||
// The generated header defines ADataType, BDataType, AccDataType, CDataType
|
||||
std::string dtype_a = DataTypeTraits<ADataType>::name;
|
||||
std::string dtype_b = DataTypeTraits<BDataType>::name;
|
||||
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
|
||||
std::string dtype_c = DataTypeTraits<CDataType>::name;
|
||||
std::string dtype_d0 = DataTypeTraits<D0DataType>::name;
|
||||
std::string dtype_d1 = DataTypeTraits<D1DataType>::name;
|
||||
|
||||
// Layout names from the layout types
|
||||
std::string layout_a = ALayout::name;
|
||||
std::string layout_b = BLayout::name;
|
||||
std::string layout_c = CLayout::name;
|
||||
std::string layout_d0 = D0Layout::name;
|
||||
std::string layout_d1 = D1Layout::name;
|
||||
|
||||
// Create GemmMultiDProblem struct
|
||||
GemmMultiDProblem gemm_multi_d_problem{arg_parser.get_int("split_k"),
|
||||
arg_parser.get_int("m"),
|
||||
arg_parser.get_int("n"),
|
||||
arg_parser.get_int("k"),
|
||||
arg_parser.get_int("stride_a"),
|
||||
arg_parser.get_int("stride_b"),
|
||||
arg_parser.get_int("stride_ds"),
|
||||
arg_parser.get_int("stride_ds"),
|
||||
arg_parser.get_int("stride_c"),
|
||||
dtype_a,
|
||||
dtype_b,
|
||||
dtype_d0,
|
||||
dtype_d1,
|
||||
dtype_acc,
|
||||
dtype_c,
|
||||
layout_a,
|
||||
layout_b,
|
||||
layout_d0,
|
||||
layout_d1,
|
||||
layout_c};
|
||||
|
||||
// Create Setting struct
|
||||
Setting setting{arg_parser.get_int("warmup"),
|
||||
arg_parser.get_int("repeat"),
|
||||
arg_parser.get_bool("timer"),
|
||||
arg_parser.get_int("verify"),
|
||||
arg_parser.get_int("init"),
|
||||
arg_parser.get_bool("log"),
|
||||
arg_parser.get_str("csv_filename"),
|
||||
arg_parser.get_bool("flush_cache"),
|
||||
arg_parser.get_int("rotating_count"),
|
||||
arg_parser.get_bool("json_output")};
|
||||
|
||||
// Get the profiler instance
|
||||
auto& profiler = GemmMultiDProfiler::instance(setting);
|
||||
|
||||
try
|
||||
{
|
||||
// Create a lambda that wraps the kernel launch
|
||||
auto kernel_func = [](const ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args,
|
||||
const ck_tile::stream_config& stream) {
|
||||
return SelectedKernel::launch(args, stream);
|
||||
};
|
||||
|
||||
// Benchmark the kernel
|
||||
profiler.benchmark(gemm_multi_d_problem, kernel_func);
|
||||
|
||||
// Select best instance based on metric
|
||||
profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Benchmark failed: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
auto [result, parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
|
||||
benchmark_single(parser);
|
||||
return 0;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << "\n";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
@@ -1,196 +0,0 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Mappings and utility functions for kernel code generation.
|
||||
"""
|
||||
|
||||
DATA_TYPE_MAP = {
|
||||
"fp32": "float",
|
||||
"fp16": "ck_tile::half_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"int8": "ck_tile::int8_t",
|
||||
"fp8": "ck_tile::fp8_t",
|
||||
"bf8": "ck_tile::bf8_t",
|
||||
"int4": "ck_tile::pk_int4_t",
|
||||
"int32": "ck_tile::int32_t",
|
||||
}
|
||||
|
||||
LAYOUT_MAP = {
|
||||
"r": "ck_tile::tensor_layout::gemm::RowMajor",
|
||||
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
|
||||
}
|
||||
|
||||
|
||||
# TODO THIS IS NOT SUPPORTED FOR MULTI D AS OF NOW
|
||||
# DEFAULT_EPILOGUE = """
|
||||
# using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
# ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
|
||||
# BDataType,
|
||||
# AccDataType,
|
||||
# CDataType,
|
||||
# CLayout,
|
||||
# kPadM,
|
||||
# kPadN,
|
||||
# WarpTileM,
|
||||
# WarpTileN,
|
||||
# WarpTileK,
|
||||
# UniversalGemmProblem::TransposeC,
|
||||
# true,
|
||||
# memory_operation>>;
|
||||
# """
|
||||
|
||||
CSHUFFLE_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
WarpM,
|
||||
WarpN,
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
"""
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", "ck_tile::GemmPipelineAgBgCrMem"],
|
||||
"compv3": [
|
||||
"ck_tile::BaseGemmPipelineAgBgCrCompV3",
|
||||
"ck_tile::GemmPipelineAgBgCrCompV3",
|
||||
],
|
||||
"compv4": [
|
||||
"ck_tile::BaseGemmPipelineAgBgCrCompV4",
|
||||
"ck_tile::GemmPipelineAgBgCrCompV4",
|
||||
],
|
||||
}
|
||||
|
||||
SCHEDULER_MAP = {
|
||||
"interwave": "ck_tile::GemmPipelineScheduler::Interwave",
|
||||
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
|
||||
}
|
||||
|
||||
# EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE}
|
||||
|
||||
EPILOGUE_MAP = {"cshuffle": CSHUFFLE_EPILOGUE}
|
||||
|
||||
|
||||
def BOOL_MAP(b_):
|
||||
return {True: "true", False: "false"}[bool(b_)]
|
||||
|
||||
|
||||
# Can add some more supported combinations
|
||||
warp_tile_supported_combinations = {
|
||||
"gfx90a": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
},
|
||||
"gfx942": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
|
||||
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
|
||||
},
|
||||
"gfx950": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"bf8_bf8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 32],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# Remove some unsupported combinations
|
||||
trait_unsupported_combinations = {
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
("compv3", "default", "interwave"),
|
||||
("compv4", "cshuffle", "interwave"),
|
||||
("compv4", "default", "interwave"),
|
||||
}
|
||||
|
||||
|
||||
ELEMENT_SIZE_MAP = {
|
||||
"fp16": 2,
|
||||
"bf16": 2,
|
||||
"int8": 1,
|
||||
"fp8": 1,
|
||||
"bf8": 1,
|
||||
"int4": 0.5,
|
||||
"int32": 4,
|
||||
}
|
||||
|
||||
|
||||
def element_size(data_type: str) -> float:
|
||||
"""Calculate the size (in bytes) of a single element for given data type."""
|
||||
data_type = data_type.lower()
|
||||
if data_type not in ELEMENT_SIZE_MAP:
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
return ELEMENT_SIZE_MAP[data_type]
|
||||
100
tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp
Normal file
100
tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp
Normal file
@@ -0,0 +1,100 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
|
||||
//[TODO] This can be moved to commons
|
||||
// DataTypeTraits for all supported types
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
// Helper function to determine if a layout is row-major
|
||||
template <typename Layout>
|
||||
constexpr auto is_row_major(Layout)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// Structure to hold kernel traits for dispatcher
|
||||
struct KernelTraits
|
||||
{
|
||||
std::string pipeline; // compv3, compv4, mem
|
||||
std::string scheduler; // intrawave, interwave
|
||||
std::string epilogue; // cshuffle, default
|
||||
bool pad_m;
|
||||
bool pad_n;
|
||||
bool pad_k;
|
||||
bool persistent;
|
||||
|
||||
// Constructor with defaults
|
||||
KernelTraits()
|
||||
: pipeline("compv3"),
|
||||
scheduler("intrawave"),
|
||||
epilogue("cshuffle"),
|
||||
pad_m(false),
|
||||
pad_n(false),
|
||||
pad_k(false),
|
||||
persistent(false)
|
||||
{
|
||||
}
|
||||
};
|
||||
@@ -1,250 +0,0 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Handles loading, parsing, and validation of JSON and Argument configuration parameters.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union, Type
|
||||
import json
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnumConfigParam:
|
||||
"""Represents an enumeration-type configuration parameter"""
|
||||
|
||||
values: List[Union[int, str, bool]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RangeConfigParam:
|
||||
"""Represents a numeric range-type configuration parameter"""
|
||||
|
||||
min: int
|
||||
max: int
|
||||
step: int
|
||||
exclude: Optional[List[int]]
|
||||
|
||||
def generate_candidates(self) -> List[int]:
|
||||
"""Generates valid candidates after applying range constraints"""
|
||||
|
||||
if self.min > self.max:
|
||||
raise ValueError(f"Invalid range: min({self.min}) > max({self.max})")
|
||||
if self.step <= 0:
|
||||
raise ValueError(f"Step must be positive, got {self.step}")
|
||||
|
||||
candidates = list(range(self.min, self.max + 1, self.step))
|
||||
|
||||
if hasattr(self, "exclude") and self.exclude:
|
||||
if not isinstance(self.exclude, list):
|
||||
raise TypeError("exclude must be list type")
|
||||
exclude_set = set(self.exclude)
|
||||
candidates = [x for x in candidates if x not in exclude_set]
|
||||
|
||||
if not candidates:
|
||||
raise ValueError(
|
||||
f"No valid candidates for range [{self.min}-{self.max}] "
|
||||
f"with step {self.step} and excludes {self.exclude}"
|
||||
)
|
||||
|
||||
return candidates
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataType:
|
||||
"""Configuration class for data type parameter."""
|
||||
|
||||
a_datatype: str
|
||||
b_datatype: str
|
||||
e_datatype: str
|
||||
d0_datatype: str
|
||||
d1_datatype: str
|
||||
ds_datatype: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Layout:
|
||||
"""Configuration class for Layout parameter."""
|
||||
|
||||
a_layout: str
|
||||
b_layout: str
|
||||
e_layout: str
|
||||
d0_layout: str
|
||||
d1_layout: str
|
||||
ds_layout: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArgumentConfig:
|
||||
"""Configuration class for Argument parameter."""
|
||||
|
||||
datatypes: DataType
|
||||
layouts: Layout
|
||||
function_name: str
|
||||
|
||||
@classmethod
|
||||
def from_args(
|
||||
cls: Type["ArgumentConfig"],
|
||||
datatype: str,
|
||||
layout: str,
|
||||
elementwise_function: str,
|
||||
) -> "ArgumentConfig":
|
||||
"""configuration loader with validation controls"""
|
||||
|
||||
datatypes = DataType(
|
||||
a_datatype=datatype,
|
||||
b_datatype=datatype,
|
||||
e_datatype=datatype,
|
||||
d0_datatype=datatype,
|
||||
d1_datatype=datatype,
|
||||
ds_datatype=[datatype, datatype],
|
||||
)
|
||||
|
||||
layout_parts = layout.lower()
|
||||
assert len(layout_parts) == 4, (
|
||||
f"Invalid layout string: {layout} (must be 4 characters like 'rcrr' where r stands for row major and c stands for column major)"
|
||||
)
|
||||
assert layout_parts[0] in ("r", "c"), (
|
||||
f"Invalid matrix_a layout: {layout_parts[0]} (must be 'r' for row major or or 'c' for column major)"
|
||||
)
|
||||
assert layout_parts[1] in ("r", "c"), (
|
||||
f"Invalid matrix_b layout: {layout_parts[1]} (must be 'r' for row major or or 'c' for column major)"
|
||||
)
|
||||
assert layout_parts[2] == "r", (
|
||||
f"Invalid matrix_e layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
|
||||
)
|
||||
assert layout_parts[3] == "r", (
|
||||
f"Invalid D dimension layout: {layout_parts[3]} (must be 'r' only as currently we are supporting only row major)"
|
||||
)
|
||||
|
||||
layouts = Layout(
|
||||
a_layout=layout[0],
|
||||
b_layout=layout[1],
|
||||
e_layout=layout[2],
|
||||
d0_layout=layout[3],
|
||||
d1_layout=layout[3],
|
||||
ds_layout=[layout[3], layout[3]],
|
||||
)
|
||||
# Elementwise function name validation
|
||||
valid_functions = ["mul", "add", "passthrough"]
|
||||
if elementwise_function not in valid_functions:
|
||||
raise ValueError(
|
||||
f"Invalid elementwise function: {elementwise_function}. "
|
||||
f"Valid options are: {', '.join(valid_functions)}"
|
||||
)
|
||||
|
||||
# Set the function name based on the elementwise function
|
||||
if elementwise_function == "mul":
|
||||
function_name = "MultiDMultiply"
|
||||
elif elementwise_function == "add":
|
||||
function_name = "MultiDAdd"
|
||||
elif elementwise_function == "passthrough":
|
||||
function_name = "PassThrough" # TODO Change this
|
||||
|
||||
return cls(datatypes=datatypes, layouts=layouts, function_name=function_name)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TileConfig:
|
||||
"""Configuration class for tile parameter."""
|
||||
|
||||
tile_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
tile_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
tile_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
warp_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
warp_tile_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_tile_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_tile_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TraitConfig:
|
||||
"""Configuration class for kernel traits."""
|
||||
|
||||
pipeline: EnumConfigParam
|
||||
scheduler: EnumConfigParam
|
||||
epilogue: EnumConfigParam
|
||||
pad_m: EnumConfigParam
|
||||
pad_n: EnumConfigParam
|
||||
pad_k: EnumConfigParam
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonConfig:
|
||||
"""Configuration class for JSON parameter."""
|
||||
|
||||
tile_config: TileConfig
|
||||
trait_config: TraitConfig
|
||||
|
||||
@classmethod
|
||||
def from_json(cls: Type["JsonConfig"], filepath: str) -> "JsonConfig":
|
||||
"""JSON configuration loader with validation controls"""
|
||||
config_path = Path(filepath)
|
||||
|
||||
try:
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Config file {filepath} not found")
|
||||
|
||||
with config_path.open("r") as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Parse tile config
|
||||
def create_param(param_dict):
|
||||
if "values" in param_dict:
|
||||
return EnumConfigParam(values=param_dict["values"])
|
||||
else:
|
||||
return RangeConfigParam(
|
||||
min=param_dict["min"],
|
||||
max=param_dict["max"],
|
||||
step=param_dict["step"],
|
||||
exclude=param_dict.get("exclude", []),
|
||||
)
|
||||
|
||||
tile_config = TileConfig(
|
||||
tile_m=create_param(config_dict["tile_config"]["tile_m"]),
|
||||
tile_n=create_param(config_dict["tile_config"]["tile_n"]),
|
||||
tile_k=create_param(config_dict["tile_config"]["tile_k"]),
|
||||
warp_m=create_param(config_dict["tile_config"]["warp_m"]),
|
||||
warp_n=create_param(config_dict["tile_config"]["warp_n"]),
|
||||
warp_k=create_param(config_dict["tile_config"]["warp_k"]),
|
||||
warp_tile_m=create_param(config_dict["tile_config"]["warp_tile_m"]),
|
||||
warp_tile_n=create_param(config_dict["tile_config"]["warp_tile_n"]),
|
||||
warp_tile_k=create_param(config_dict["tile_config"]["warp_tile_k"]),
|
||||
)
|
||||
|
||||
# Parse trait config
|
||||
trait_config = TraitConfig(
|
||||
pipeline=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pipeline"]["values"]
|
||||
),
|
||||
scheduler=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["scheduler"]["values"]
|
||||
),
|
||||
epilogue=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["epilogue"]["values"]
|
||||
),
|
||||
pad_m=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pad_m"]["values"]
|
||||
),
|
||||
pad_n=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pad_n"]["values"]
|
||||
),
|
||||
pad_k=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pad_k"]["values"]
|
||||
),
|
||||
)
|
||||
|
||||
return cls(tile_config=tile_config, trait_config=trait_config)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON format: {str(e)}")
|
||||
except KeyError as e:
|
||||
raise KeyError(f"Missing required configuration field: {str(e)}")
|
||||
@@ -1,164 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_multi_d_dispatcher.hpp"
|
||||
#include "gemm_multi_d_common.hpp"
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
inline auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.")
|
||||
.insert("n", "4096", "The value for n dimension. Default is 4096.")
|
||||
.insert("k", "2048", "The value for k dimension. Default is 2048.")
|
||||
.insert("stride_a", "0", "The stride value for tensor A. Default is 0.")
|
||||
.insert("stride_b", "0", "The stride value for tensor B. Default is 0.")
|
||||
.insert("stride_ds", "0", "The stride value for tensor Ds Default is 0.")
|
||||
.insert("stride_e", "0", "The stride value for tensor E Default is 0.")
|
||||
.insert("split_k", "1", "The split value for k dimension. Default is 1.")
|
||||
.insert("verify",
|
||||
"1",
|
||||
"The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
|
||||
"for validation on GPU. Default is 1, validation on CPU, as validation on GPU is "
|
||||
"not supported.")
|
||||
.insert("log",
|
||||
"false",
|
||||
"Wether output kernel instance information or not. Possible values are true or "
|
||||
"false. Default is false")
|
||||
.insert("warmup",
|
||||
"50",
|
||||
"The number of iterations before benchmarking the kernel. Default is 50.")
|
||||
.insert("repeat",
|
||||
"100",
|
||||
"The number of iterations for benchmarking the kernel. Default is 100.")
|
||||
.insert("timer",
|
||||
"true",
|
||||
"Indicates whether the timer is a GPU timer. Possible values are true or false. "
|
||||
"Default is true.")
|
||||
.insert("init",
|
||||
"0",
|
||||
"The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
|
||||
"for constant(1). Default is 0, random.")
|
||||
.insert("flush_cache",
|
||||
"false",
|
||||
"To flush cache, possible values are true or false. "
|
||||
"Default is false.")
|
||||
.insert("rotating_count", "5", "number of iterations to rotate the cache. default is 5.")
|
||||
.insert("metric",
|
||||
"0",
|
||||
"Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
|
||||
"tflops, or 2 for bandwidth. Default is 0, latency.")
|
||||
.insert("csv_filename",
|
||||
"gemm_multi_d_kernel",
|
||||
"The filename of benchmark result. Default is set to gemm_multi_d_kernel.")
|
||||
.insert(
|
||||
"pipeline",
|
||||
"compv3",
|
||||
"The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.")
|
||||
.insert("scheduler",
|
||||
"intrawave",
|
||||
"The type of pipeline. Possible values are compv3, compv4 or mem. Default is "
|
||||
"compv3.")
|
||||
.insert(
|
||||
"epilogue",
|
||||
"cshuffle",
|
||||
"The type of epilogue. Possible values are cshuffle or default. Default is cshuffle.")
|
||||
.insert("pad_m",
|
||||
"false",
|
||||
"Whether pad or not in m direction. Possible values are true or false. Default is "
|
||||
"false.")
|
||||
.insert("pad_n",
|
||||
"false",
|
||||
"Whether pad or not in n direction. Possible values are true or false. Default is "
|
||||
"false.")
|
||||
.insert("pad_k",
|
||||
"false",
|
||||
"Whether pad or not in k direction. Possible values are true or false. Default is "
|
||||
"false.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
auto get_kernel_func_by_trait(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
KernelTraits trait;
|
||||
trait.pipeline = arg_parser.get_str("pipeline");
|
||||
trait.scheduler = arg_parser.get_str("scheduler");
|
||||
trait.epilogue = arg_parser.get_str("epilogue");
|
||||
trait.pad_m = arg_parser.get_bool("pad_m");
|
||||
trait.pad_n = arg_parser.get_bool("pad_n");
|
||||
trait.pad_k = arg_parser.get_bool("pad_k");
|
||||
|
||||
return GemmMultiDDispatcher::dispatch(trait);
|
||||
}
|
||||
1454
tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py
Executable file → Normal file
1454
tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py
Executable file → Normal file
File diff suppressed because it is too large
Load Diff
@@ -9,7 +9,7 @@
|
||||
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "benchmark_gemm_multi_d.hpp"
|
||||
#include "gemm_multi_d_benchmark.hpp"
|
||||
|
||||
class GemmMultiDProfiler
|
||||
{
|
||||
@@ -20,6 +20,25 @@ class GemmMultiDProfiler
|
||||
return instance;
|
||||
}
|
||||
|
||||
// Overload for single kernel benchmarking
|
||||
void benchmark(GemmMultiDProblem& gemm_multi_d_problem,
|
||||
std::function<float(const ck_tile::GemmMultiDHostArgs<DsDataType::size()>&,
|
||||
const ck_tile::stream_config&)> kernel_func)
|
||||
{
|
||||
// Create a vector with a single callable that returns both name and time
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
ck_tile::GemmMultiDHostArgs<DsDataType::size()>&, const ck_tile::stream_config&)>>
|
||||
callables;
|
||||
|
||||
callables.push_back([kernel_func](ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args,
|
||||
const ck_tile::stream_config& stream) {
|
||||
float time = kernel_func(args, stream);
|
||||
return std::make_tuple(std::string(KERNEL_NAME), time);
|
||||
});
|
||||
|
||||
benchmark(gemm_multi_d_problem, callables);
|
||||
}
|
||||
|
||||
void benchmark(
|
||||
GemmMultiDProblem& gemm_multi_d_problem,
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
@@ -30,7 +49,7 @@ class GemmMultiDProfiler
|
||||
const BLayout layout_b = BLayout{};
|
||||
const D0Layout layout_d0 = D0Layout{};
|
||||
const D1Layout layout_d1 = D1Layout{};
|
||||
const ELayout layout_e = ELayout{};
|
||||
const CLayout layout_c = CLayout{};
|
||||
|
||||
gemm_multi_d_problem.stride_a_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.k_,
|
||||
@@ -50,10 +69,10 @@ class GemmMultiDProfiler
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_d1_,
|
||||
is_row_major(layout_d1));
|
||||
gemm_multi_d_problem.stride_e_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.stride_c_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_e_,
|
||||
is_row_major(layout_e));
|
||||
gemm_multi_d_problem.stride_c_,
|
||||
is_row_major(layout_c));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
|
||||
@@ -75,30 +94,30 @@ class GemmMultiDProfiler
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_d1_,
|
||||
is_row_major(layout_d1)));
|
||||
ck_tile::HostTensor<EDataType> e_m_n_device_result(
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_e_,
|
||||
is_row_major(layout_e)));
|
||||
gemm_multi_d_problem.stride_c_,
|
||||
is_row_major(layout_c)));
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(d1_m_n);
|
||||
ck_tile::FillUniformDistribution<D1DataType>{-1.f, 1.f}(d1_m_n);
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.mData.data());
|
||||
d0_m_n_dev_buf.ToDevice(d0_m_n.mData.data());
|
||||
d1_m_n_dev_buf.ToDevice(d1_m_n.mData.data());
|
||||
|
||||
e_m_n_dev_buf.SetZero();
|
||||
e_m_n_device_result.SetZero();
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
std::array<const void*, DsDataType::size()> ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(),
|
||||
d1_m_n_dev_buf.GetDeviceBuffer()};
|
||||
@@ -110,7 +129,7 @@ class GemmMultiDProfiler
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
ds_ptr_buf,
|
||||
e_m_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
gemm_multi_d_problem.split_k_,
|
||||
gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
@@ -118,19 +137,19 @@ class GemmMultiDProfiler
|
||||
gemm_multi_d_problem.stride_a_,
|
||||
gemm_multi_d_problem.stride_b_,
|
||||
stridesDs,
|
||||
gemm_multi_d_problem.stride_e_,
|
||||
gemm_multi_d_problem.stride_c_,
|
||||
};
|
||||
|
||||
ck_tile::HostTensor<EDataType> e_m_n_host_result(
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_result(
|
||||
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_e_,
|
||||
is_row_major(layout_e)));
|
||||
gemm_multi_d_problem.stride_c_,
|
||||
is_row_major(layout_c)));
|
||||
|
||||
if(setting_.verify_)
|
||||
{
|
||||
gemm_multi_d_host_reference(
|
||||
setting_.verify_, a_m_k, b_k_n, d0_m_n, d1_m_n, e_m_n_host_result);
|
||||
setting_.verify_, a_m_k, b_k_n, d0_m_n, d1_m_n, c_m_n_host_result);
|
||||
}
|
||||
|
||||
for(auto& callable : callables)
|
||||
@@ -139,54 +158,58 @@ class GemmMultiDProfiler
|
||||
callable(gemm_multi_d_args,
|
||||
ck_tile::stream_config{
|
||||
nullptr, true, setting_.log_, setting_.n_warmup_, setting_.n_repeat_});
|
||||
|
||||
auto [kernel_name, execution_time] = kernel_run_result;
|
||||
|
||||
process_result(gemm_multi_d_problem,
|
||||
e_m_n_dev_buf,
|
||||
e_m_n_host_result,
|
||||
e_m_n_device_result,
|
||||
c_m_n_dev_buf,
|
||||
c_m_n_host_result,
|
||||
c_m_n_dev_result,
|
||||
kernel_run_result);
|
||||
}
|
||||
}
|
||||
|
||||
void process_result(const GemmMultiDProblem& gemm_multi_d_problem,
|
||||
ck_tile::DeviceMem& e_m_n_dev_buf,
|
||||
ck_tile::HostTensor<EDataType>& e_m_n_host_result,
|
||||
ck_tile::HostTensor<EDataType>& e_m_n_dev_result,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
const std::tuple<std::string, float>& kernel_run_result)
|
||||
{
|
||||
auto [name, avg_time] = kernel_run_result;
|
||||
|
||||
KernelInstance kernel_instance{name, gemm_multi_d_problem, {-1.0f, -1.0f, -1.0f}};
|
||||
|
||||
static constexpr ck_tile::index_t NumDTensor = DsDataType::size();
|
||||
std::size_t flop = 0, num_byte = 0;
|
||||
flop += std::size_t(2) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_ *
|
||||
gemm_multi_d_problem.k_;
|
||||
ck_tile::static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
// compute performance metric
|
||||
std::size_t flop = std::size_t(2) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_ *
|
||||
gemm_multi_d_problem.k_;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.k_ +
|
||||
sizeof(BDataType) * gemm_multi_d_problem.n_ * gemm_multi_d_problem.k_ +
|
||||
sizeof(CDataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;
|
||||
|
||||
// Dth Dimension Updates
|
||||
ck_tile::static_for<0, DsDataType::size(), 1>{}([&](auto i) {
|
||||
num_byte += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) *
|
||||
gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;
|
||||
flop += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) *
|
||||
gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;
|
||||
});
|
||||
num_byte += sizeof(ADataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.k_ +
|
||||
sizeof(BDataType) * gemm_multi_d_problem.k_ * gemm_multi_d_problem.n_ +
|
||||
sizeof(EDataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;
|
||||
|
||||
// update
|
||||
kernel_instance.perf_result_.latency_ = avg_time;
|
||||
kernel_instance.perf_result_.tflops_ = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;
|
||||
|
||||
if(setting_.log_ > 0)
|
||||
if(setting_.log_ > 0 && !setting_.json_output_)
|
||||
{
|
||||
std::cout << kernel_instance << std::endl;
|
||||
}
|
||||
|
||||
e_m_n_dev_buf.FromDevice(e_m_n_dev_result.data());
|
||||
// verify result
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool verified_correct =
|
||||
!setting_.verify_ ||
|
||||
compare(name, gemm_multi_d_problem.k_, e_m_n_dev_result, e_m_n_host_result);
|
||||
!setting_.verify_ || compare(name,
|
||||
gemm_multi_d_problem.k_,
|
||||
1, // Multi d currently supports only k_batch = 1
|
||||
c_m_n_dev_result,
|
||||
c_m_n_host_result);
|
||||
|
||||
if(verified_correct)
|
||||
{
|
||||
@@ -197,8 +220,9 @@ class GemmMultiDProfiler
|
||||
std::cout << "Verification failed, skip kernel: " << name << std::endl;
|
||||
}
|
||||
|
||||
e_m_n_dev_buf.SetZero();
|
||||
e_m_n_dev_result.SetZero();
|
||||
// clear tensor
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
}
|
||||
|
||||
KernelInstance select_best_instance(Metric metric)
|
||||
@@ -213,10 +237,18 @@ class GemmMultiDProfiler
|
||||
b.perf_result_, a.perf_result_, metric);
|
||||
});
|
||||
|
||||
std::cout << "**********************************" << std::endl;
|
||||
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
|
||||
<< "The best kernel instance is: " << kernel_instance << std::endl;
|
||||
std::cout << "**********************************" << std::endl;
|
||||
if(setting_.json_output_)
|
||||
{
|
||||
// Output clean JSON only
|
||||
std::cout << kernel_instance << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "**********************************" << std::endl;
|
||||
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
|
||||
<< "Current kernel performance is: " << kernel_instance << std::endl;
|
||||
std::cout << "**********************************" << std::endl;
|
||||
}
|
||||
|
||||
if(!setting_.csv_filename_.empty())
|
||||
{
|
||||
@@ -244,16 +276,13 @@ class GemmMultiDProfiler
|
||||
file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
|
||||
<< problem.split_k_ << "," << problem.m_ << "," << problem.n_ << ","
|
||||
<< problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
|
||||
<< problem.stride_d0_ << "," << problem.stride_d1_ << "," << problem.stride_e_
|
||||
<< "," << problem.dtype_a_ << "," << problem.dtype_b_ << ","
|
||||
<< problem.dtype_d0_ << "," << problem.dtype_d1_ << "," << problem.dtype_acc_
|
||||
<< "," << problem.dtype_e_ << "," << problem.layout_a_ << ","
|
||||
<< problem.layout_b_ << "," << problem.layout_d0_ << "," << problem.layout_d1_
|
||||
<< "," << problem.layout_e_ << "," << "," << name << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.latency_ << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.tflops_ << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)
|
||||
<< "\n";
|
||||
<< problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_
|
||||
<< "," << problem.dtype_acc_ << "," << problem.dtype_c_ << ","
|
||||
<< problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_
|
||||
<< "," << name << "," << std::fixed << std::setprecision(4) << perf.latency_
|
||||
<< "," << std::fixed << std::setprecision(4) << perf.tflops_ << ","
|
||||
<< std::fixed << std::setprecision(4) << perf.bandwidth_ << ","
|
||||
<< get_metric_name(metric) << "\n";
|
||||
|
||||
if(!file)
|
||||
{
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
set(GEMM_PRESHUFFLE_DATATYPE "fp16;fp8" CACHE STRING "List of datatypes for GEMM Preshuffle (semicolon-separated)")
|
||||
set(GEMM_PRESHUFFLE_DATATYPE "fp16;fp8;bf16;bf8" CACHE STRING "List of datatypes for GEMM Preshuffle (semicolon-separated)")
|
||||
set(GEMM_PRESHUFFLE_LAYOUT "rcr" CACHE STRING "List of layout for GEMM Preshuffle (semicolon-separated)")
|
||||
set(GEMM_PRESHUFFLE_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
|
||||
option(ENABLE_CCACHE_GEMM_PRESHUFFLE "Enable ccache for GEMM Preshuffle ops compilation" OFF)
|
||||
@@ -122,15 +122,15 @@ function(build_individual_gemm_preshuffle_targets datatype layout)
|
||||
if(DEFINED ENV{GEMM_PRESHUFFLE_CONFIG_FILE} AND NOT "$ENV{GEMM_PRESHUFFLE_CONFIG_FILE}" STREQUAL "")
|
||||
set(config_filename "$ENV{GEMM_PRESHUFFLE_CONFIG_FILE}")
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
|
||||
message(STATUS " Using config from environment variable: ${config_filename}")
|
||||
message(VERBOSE " Using config from environment variable: ${config_filename}")
|
||||
elseif(NOT "${GEMM_PRESHUFFLE_CONFIG_FILE}" STREQUAL "")
|
||||
# Use CMake variable if set
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_PRESHUFFLE_CONFIG_FILE}")
|
||||
message(STATUS " Using custom config: ${GEMM_PRESHUFFLE_CONFIG_FILE}")
|
||||
message(VERBOSE " Using custom config: ${GEMM_PRESHUFFLE_CONFIG_FILE}")
|
||||
else()
|
||||
# Use default config for all layouts
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
|
||||
message(STATUS " Using default config for layout ${layout}")
|
||||
message(VERBOSE " Using default config for layout ${layout}")
|
||||
endif()
|
||||
|
||||
# Check if config file exists
|
||||
@@ -151,18 +151,18 @@ function(build_individual_gemm_preshuffle_targets datatype layout)
|
||||
endif()
|
||||
|
||||
# Generate individual kernel files using parallel version
|
||||
message(STATUS "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
|
||||
message(STATUS " Working path: ${working_path}")
|
||||
message(STATUS " Config file: ${json_blob}")
|
||||
message(STATUS " Python executable: ${Python3_EXECUTABLE}")
|
||||
message(STATUS " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_preshuffle_instance_builder.py")
|
||||
message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
|
||||
message(VERBOSE " Working path: ${working_path}")
|
||||
message(VERBOSE " Config file: ${json_blob}")
|
||||
message(VERBOSE " Python executable: ${Python3_EXECUTABLE}")
|
||||
message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_preshuffle_instance_builder.py")
|
||||
|
||||
# Create working directory first
|
||||
file(MAKE_DIRECTORY ${working_path})
|
||||
|
||||
# First, just list the kernels (fast operation)
|
||||
message(STATUS " Listing kernel configurations...")
|
||||
message(STATUS " GPU Targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}")
|
||||
message(VERBOSE " Listing kernel configurations...")
|
||||
message(VERBOSE " GPU Targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}")
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_preshuffle_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
@@ -185,7 +185,7 @@ function(build_individual_gemm_preshuffle_targets datatype layout)
|
||||
if(EXISTS ${working_path}/gemm_preshuffle_kernel_count.txt)
|
||||
file(READ ${working_path}/gemm_preshuffle_kernel_count.txt kernel_count)
|
||||
string(STRIP "${kernel_count}" kernel_count)
|
||||
message(STATUS " Found ${kernel_count} kernel configurations")
|
||||
message(VERBOSE " Found ${kernel_count} kernel configurations")
|
||||
else()
|
||||
message(FATAL_ERROR "Kernel count file not found")
|
||||
endif()
|
||||
@@ -209,10 +209,10 @@ function(build_individual_gemm_preshuffle_targets datatype layout)
|
||||
endfunction()
|
||||
|
||||
# Main build logic - Only individual builds supported
|
||||
message(STATUS "=== Starting Tile Engine GEMM Preshuffle Configuration ===")
|
||||
message(STATUS "GEMM_PRESHUFFLE_DATATYPE: ${GEMM_PRESHUFFLE_DATATYPE}")
|
||||
message(STATUS "GEMM_PRESHUFFLE_LAYOUT: ${GEMM_PRESHUFFLE_LAYOUT}")
|
||||
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
message(VERBOSE "=== Starting Tile Engine GEMM Preshuffle Configuration ===")
|
||||
message(VERBOSE "GEMM_PRESHUFFLE_DATATYPE: ${GEMM_PRESHUFFLE_DATATYPE}")
|
||||
message(VERBOSE "GEMM_PRESHUFFLE_LAYOUT: ${GEMM_PRESHUFFLE_LAYOUT}")
|
||||
message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
# Filter GPU targets to only gfx90a, gfx942, and gfx950
|
||||
set(GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL "")
|
||||
@@ -221,7 +221,7 @@ set(DESIRED_TARGETS "gfx90a;gfx942;gfx950")
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
list(APPEND GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL ${target})
|
||||
message(STATUS " Adding GPU target: ${target}")
|
||||
message(VERBOSE " Adding GPU target: ${target}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -229,7 +229,7 @@ endforeach()
|
||||
if(NOT GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL)
|
||||
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
else()
|
||||
message(STATUS "Building individual GEMM Preshuffle targets for GPU targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}")
|
||||
message(VERBOSE "Building individual GEMM Preshuffle targets for GPU targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}")
|
||||
|
||||
# Enable parallel compilation optimizations
|
||||
# Set up job pools for better parallel compilation control
|
||||
@@ -244,12 +244,12 @@ else()
|
||||
find_program(CCACHE_PROGRAM ccache)
|
||||
if(CCACHE_PROGRAM)
|
||||
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
|
||||
message(STATUS "Using ccache for faster compilation")
|
||||
message(VERBOSE "Using ccache for faster compilation")
|
||||
else()
|
||||
message(WARNING "ccache requested but not found")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "ccache disabled for GEMM Preshuffle ops (use -DENABLE_CCACHE_GEMM_PRESHUFFLE=ON to enable)")
|
||||
message(VERBOSE "ccache disabled for GEMM Preshuffle ops (use -DENABLE_CCACHE_GEMM_PRESHUFFLE=ON to enable)")
|
||||
endif()
|
||||
|
||||
# Create master collection targets
|
||||
|
||||
Reference in New Issue
Block a user