This commit is contained in:
Sami Remes
2025-08-19 08:01:33 +00:00
parent 26d3300930
commit abcf2f3c97
4 changed files with 313 additions and 260 deletions

View File

@@ -121,6 +121,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
KPack,
true>;
using Base::I0;
using Base::I1;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
@@ -340,6 +341,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
auto c_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
c_scale_thread_desc.GetElementSpaceSize());
StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
@@ -352,7 +356,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, I0),
a_scale_thread_buf);
a_scale_thread_bufs(I0));
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<0>{}));
});
@@ -372,12 +376,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_buf);
b_scale_thread_bufs(I0));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
});
// Local prefill 1
@@ -396,7 +400,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, I0),
a_scale_thread_buf);
a_scale_thread_bufs(I1));
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<0>{}));
});
@@ -416,16 +420,19 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_buf);
b_scale_thread_bufs(I1));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
// Initialize C
c_thread_buf.Clear();
// Double register buffer for non-scaled gemm computation
// 1. Reduce register pressure
// 2. Decouple the dependency between mfma instruction and scale-fma instruction
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
AccDataType,
1,
2,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_per_scale;
@@ -459,124 +466,161 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
index_t i = 0;
do
{
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, MRepeat, 1>{}([&](auto m0) {
vector_type<AccDataType, 2> c_scale_thread_vec;
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) = c_scale_thread_buf[m0];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) = c_scale_thread_buf[m0];
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto mfma_buf_offset =
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
constexpr auto scale_buf_offset =
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
// Clear buffer for new MFMA computation
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{}));
});
// Apply scaling with packed FMA and accumulate to main buffer
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(c_scale_thread_buf[m0]);
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale
.GetVectorTypeReference(Number<scale_buf_offset>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
});
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, I0),
a_scale_thread_bufs(local_read_buf));
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
if constexpr(NumKBlockPerScale == 1)
{
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
}
else
{
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
}
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_bufs(local_read_buf));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
// Update scales for next iteration using the loaded values
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * b_scale_thread_bufs[mfma_reg_buf][I0];
});
});
};
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
LoopFunc(I0, I1);
LoopFunc(I1, I0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, I0),
a_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
});
if constexpr(NumKBlockPerScale == 1)
{
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
}
else
{
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
}
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
i += 1;
} while(i < (num_loop - 1));
i += 2;
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
vector_type<AccDataType, 2> c_scale_thread_vec;
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) = c_scale_thread_buf[m0];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) = c_scale_thread_buf[m0];
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto mfma_buf_offset =
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
constexpr auto scale_buf_offset =
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
// Clear buffer for new MFMA computation
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
@@ -596,15 +640,24 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(c_scale_thread_buf[m0]);
// Apply scaling with packed FMA and accumulate to main buffer
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale
.GetVectorTypeReference(Number<scale_buf_offset>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
});

View File

@@ -232,14 +232,14 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
};
constexpr index_t minimum_occupancy = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout> &&
is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
// FIXME: many instances have many spills with occupancy > 1, a better solution
// needed to get best performance
return 1;
}
else
// if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout> &&
// is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
// {
// // FIXME: many instances have many spills with occupancy > 1, a better solution
// // needed to get best performance
// return 1;
// }
// else
{
return (BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave &&
MPerBlock * NPerBlock / BlockSize > 64)

View File

@@ -20,7 +20,7 @@ list(APPEND GEMM_AB_SCALE_INSTANCES
)
# Row, Col
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-Rpass-analysis=kernel-resource-usage;-save-temps;-g;-fverbose-asm;-Wno-gnu-line-marker")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
@@ -30,7 +30,7 @@ set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_s
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
# Col, Row
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-Rpass-analysis=kernel-resource-usage;-save-temps;-g;-fverbose-asm;-Wno-gnu-line-marker")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")

View File

@@ -11,98 +11,98 @@ message(STATUS "CK_PROFILER_OP_FILTER: ${CK_PROFILER_OP_FILTER}")
message(STATUS "CK_PROFILER_INSTANCE_FILTER: ${CK_PROFILER_INSTANCE_FILTER}")
set(PROFILER_OPS
profile_gemm.cpp
profile_reduce.cpp
profile_groupnorm_bwd_data.cpp
profile_groupnorm_fwd.cpp
profile_layernorm_bwd_data.cpp
profile_layernorm_bwd_gamma_beta.cpp
profile_groupnorm_bwd_gamma_beta.cpp
profile_layernorm_fwd.cpp
profile_max_pool2d_fwd.cpp
profile_pool3d_fwd.cpp
profile_avg_pool3d_bwd.cpp
profile_max_pool3d_bwd.cpp
profile_avg_pool2d_bwd.cpp
profile_max_pool2d_bwd.cpp
profile_softmax.cpp
profile_batchnorm_fwd.cpp
profile_batchnorm_bwd.cpp
profile_batchnorm_infer.cpp
profile_conv_tensor_rearrange.cpp
profile_transpose.cpp
profile_permute_scale.cpp
# profile_gemm.cpp
# profile_reduce.cpp
# profile_groupnorm_bwd_data.cpp
# profile_groupnorm_fwd.cpp
# profile_layernorm_bwd_data.cpp
# profile_layernorm_bwd_gamma_beta.cpp
# profile_groupnorm_bwd_gamma_beta.cpp
# profile_layernorm_fwd.cpp
# profile_max_pool2d_fwd.cpp
# profile_pool3d_fwd.cpp
# profile_avg_pool3d_bwd.cpp
# profile_max_pool3d_bwd.cpp
# profile_avg_pool2d_bwd.cpp
# profile_max_pool2d_bwd.cpp
# profile_softmax.cpp
# profile_batchnorm_fwd.cpp
# profile_batchnorm_bwd.cpp
# profile_batchnorm_infer.cpp
# profile_conv_tensor_rearrange.cpp
# profile_transpose.cpp
# profile_permute_scale.cpp
)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp)
list(APPEND PROFILER_OPS profile_contraction_scale.cpp)
# list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp)
# list(APPEND PROFILER_OPS profile_contraction_scale.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_OPS profile_gemm_reduce.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp)
# list(APPEND PROFILER_OPS profile_gemm_reduce.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
# list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
# list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
# list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp)
# list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp)
# list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp)
# list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp)
# list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp)
endif()
list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp)
# list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]")
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp)
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
# list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp)
# list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp)
list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_preshuffle.cpp)
# list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp)
# list(APPEND PROFILER_OPS profile_gemm_universal_preshuffle.cpp)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
list(APPEND PROFILER_OPS profile_gemm_mx.cpp)
# list(APPEND PROFILER_OPS profile_gemm_mx.cpp)
endif()
list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp)
list(APPEND PROFILER_OPS profile_conv_bwd_data.cpp)
list(APPEND PROFILER_OPS profile_conv_fwd.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
# list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
# list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
# list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
# list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
# list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
# list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
# list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp)
# list(APPEND PROFILER_OPS profile_conv_bwd_data.cpp)
# list(APPEND PROFILER_OPS profile_conv_fwd.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp)
endif()
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR
(SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)))
list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
# list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
# list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
# list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
endif()
if(DL_KERNELS)
list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
endif()
set(PROFILER_SOURCES profiler.cpp)
@@ -129,103 +129,103 @@ endif()
set(DEVICE_INSTANCES "")
list(APPEND DEVICE_INSTANCES device_gemm_instance)
list(APPEND DEVICE_INSTANCES device_normalization_fwd_instance)
list(APPEND DEVICE_INSTANCES device_normalization_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_normalization_bwd_gamma_beta_instance)
list(APPEND DEVICE_INSTANCES device_softmax_instance)
list(APPEND DEVICE_INSTANCES device_reduce_instance)
list(APPEND DEVICE_INSTANCES device_batchnorm_instance)
list(APPEND DEVICE_INSTANCES device_pool2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_pool3d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_avg_pool2d_bwd_instance)
list(APPEND DEVICE_INSTANCES device_avg_pool3d_bwd_instance)
list(APPEND DEVICE_INSTANCES device_max_pool_bwd_instance)
list(APPEND DEVICE_INSTANCES device_image_to_column_instance)
list(APPEND DEVICE_INSTANCES device_column_to_image_instance)
list(APPEND DEVICE_INSTANCES device_transpose_instance)
list(APPEND DEVICE_INSTANCES device_permute_scale_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_instance)
# list(APPEND DEVICE_INSTANCES device_normalization_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_normalization_bwd_data_instance)
# list(APPEND DEVICE_INSTANCES device_normalization_bwd_gamma_beta_instance)
# list(APPEND DEVICE_INSTANCES device_softmax_instance)
# list(APPEND DEVICE_INSTANCES device_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_batchnorm_instance)
# list(APPEND DEVICE_INSTANCES device_pool2d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_pool3d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_avg_pool2d_bwd_instance)
# list(APPEND DEVICE_INSTANCES device_avg_pool3d_bwd_instance)
# list(APPEND DEVICE_INSTANCES device_max_pool_bwd_instance)
# list(APPEND DEVICE_INSTANCES device_image_to_column_instance)
# list(APPEND DEVICE_INSTANCES device_column_to_image_instance)
# list(APPEND DEVICE_INSTANCES device_transpose_instance)
# list(APPEND DEVICE_INSTANCES device_permute_scale_instance)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance)
list(APPEND DEVICE_INSTANCES device_contraction_scale_instance)
# list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance)
# list(APPEND DEVICE_INSTANCES device_contraction_scale_instance)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
endif()
list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]")
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
endif()
list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance)
list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_conv1d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_conv3d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance)
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_conv1d_bwd_data_instance)
# list(APPEND DEVICE_INSTANCES device_conv3d_bwd_data_instance)
# list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
endif()
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR
(SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)))
list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]")
list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
endif()
if(DL_KERNELS)
list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
endif()
set(PROFILER_LIBS utility getopt::getopt)