mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
WIP
This commit is contained in:
@@ -121,6 +121,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
KPack,
|
||||
true>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
@@ -340,6 +341,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
auto c_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
c_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
@@ -352,7 +356,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_bufs(I0));
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
@@ -372,12 +376,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
b_scale_thread_bufs(I0));
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
|
||||
c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
|
||||
});
|
||||
|
||||
// Local prefill 1
|
||||
@@ -396,7 +400,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_bufs(I1));
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
@@ -416,16 +420,19 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
b_scale_thread_bufs(I1));
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// Double register buffer for non-scaled gemm computation
|
||||
// 1. Reduce register pressure
|
||||
// 2. Decouple the dependency between mfma instruction and scale-fma instruction
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
1,
|
||||
2,
|
||||
xdlops_gemm.GetRegSizePerXdlops(),
|
||||
true>
|
||||
c_thread_buf_per_scale;
|
||||
@@ -459,124 +466,161 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) = c_scale_thread_buf[m0];
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) = c_scale_thread_buf[m0];
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
constexpr auto mfma_buf_offset =
|
||||
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
|
||||
constexpr auto scale_buf_offset =
|
||||
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
|
||||
|
||||
// Clear buffer for new MFMA computation
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{}));
|
||||
});
|
||||
|
||||
// Apply scaling with packed FMA and accumulate to main buffer
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()[Number<t>{}] *
|
||||
type_convert<AccDataType>(c_scale_thread_buf[m0]);
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
|
||||
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
|
||||
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
|
||||
c_thread_buf_per_scale
|
||||
.GetVectorTypeReference(Number<scale_buf_offset>{})
|
||||
.template AsType<pk_fma_type>()[t],
|
||||
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()[t]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
|
||||
});
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, I0),
|
||||
a_thread_buf);
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_bufs(local_read_buf));
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k, I0),
|
||||
b_thread_buf);
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_bufs(local_read_buf));
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
// Update scales for next iteration using the loaded values
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * b_scale_thread_bufs[mfma_reg_buf][I0];
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 1));
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) = c_scale_thread_buf[m0];
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) = c_scale_thread_buf[m0];
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
constexpr auto mfma_buf_offset =
|
||||
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
|
||||
constexpr auto scale_buf_offset =
|
||||
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
|
||||
|
||||
// Clear buffer for new MFMA computation
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
@@ -596,15 +640,24 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{}));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()[Number<t>{}] *
|
||||
type_convert<AccDataType>(c_scale_thread_buf[m0]);
|
||||
|
||||
// Apply scaling with packed FMA and accumulate to main buffer
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
|
||||
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
|
||||
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
|
||||
c_thread_buf_per_scale
|
||||
.GetVectorTypeReference(Number<scale_buf_offset>{})
|
||||
.template AsType<pk_fma_type>()[t],
|
||||
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()[t]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -232,14 +232,14 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout> &&
|
||||
is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
// FIXME: many instances have many spills with occupancy > 1, a better solution
|
||||
// needed to get best performance
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
// if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout> &&
|
||||
// is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
// {
|
||||
// // FIXME: many instances have many spills with occupancy > 1, a better solution
|
||||
// // needed to get best performance
|
||||
// return 1;
|
||||
// }
|
||||
// else
|
||||
{
|
||||
return (BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave &&
|
||||
MPerBlock * NPerBlock / BlockSize > 64)
|
||||
|
||||
@@ -20,7 +20,7 @@ list(APPEND GEMM_AB_SCALE_INSTANCES
|
||||
)
|
||||
|
||||
# Row, Col
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-Rpass-analysis=kernel-resource-usage;-save-temps;-g;-fverbose-asm;-Wno-gnu-line-marker")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
@@ -30,7 +30,7 @@ set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_s
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
# Col, Row
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-Rpass-analysis=kernel-resource-usage;-save-temps;-g;-fverbose-asm;-Wno-gnu-line-marker")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
|
||||
@@ -11,98 +11,98 @@ message(STATUS "CK_PROFILER_OP_FILTER: ${CK_PROFILER_OP_FILTER}")
|
||||
message(STATUS "CK_PROFILER_INSTANCE_FILTER: ${CK_PROFILER_INSTANCE_FILTER}")
|
||||
|
||||
set(PROFILER_OPS
|
||||
profile_gemm.cpp
|
||||
profile_reduce.cpp
|
||||
profile_groupnorm_bwd_data.cpp
|
||||
profile_groupnorm_fwd.cpp
|
||||
profile_layernorm_bwd_data.cpp
|
||||
profile_layernorm_bwd_gamma_beta.cpp
|
||||
profile_groupnorm_bwd_gamma_beta.cpp
|
||||
profile_layernorm_fwd.cpp
|
||||
profile_max_pool2d_fwd.cpp
|
||||
profile_pool3d_fwd.cpp
|
||||
profile_avg_pool3d_bwd.cpp
|
||||
profile_max_pool3d_bwd.cpp
|
||||
profile_avg_pool2d_bwd.cpp
|
||||
profile_max_pool2d_bwd.cpp
|
||||
profile_softmax.cpp
|
||||
profile_batchnorm_fwd.cpp
|
||||
profile_batchnorm_bwd.cpp
|
||||
profile_batchnorm_infer.cpp
|
||||
profile_conv_tensor_rearrange.cpp
|
||||
profile_transpose.cpp
|
||||
profile_permute_scale.cpp
|
||||
# profile_gemm.cpp
|
||||
# profile_reduce.cpp
|
||||
# profile_groupnorm_bwd_data.cpp
|
||||
# profile_groupnorm_fwd.cpp
|
||||
# profile_layernorm_bwd_data.cpp
|
||||
# profile_layernorm_bwd_gamma_beta.cpp
|
||||
# profile_groupnorm_bwd_gamma_beta.cpp
|
||||
# profile_layernorm_fwd.cpp
|
||||
# profile_max_pool2d_fwd.cpp
|
||||
# profile_pool3d_fwd.cpp
|
||||
# profile_avg_pool3d_bwd.cpp
|
||||
# profile_max_pool3d_bwd.cpp
|
||||
# profile_avg_pool2d_bwd.cpp
|
||||
# profile_max_pool2d_bwd.cpp
|
||||
# profile_softmax.cpp
|
||||
# profile_batchnorm_fwd.cpp
|
||||
# profile_batchnorm_bwd.cpp
|
||||
# profile_batchnorm_infer.cpp
|
||||
# profile_conv_tensor_rearrange.cpp
|
||||
# profile_transpose.cpp
|
||||
# profile_permute_scale.cpp
|
||||
)
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp)
|
||||
list(APPEND PROFILER_OPS profile_contraction_scale.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_contraction_scale.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_OPS profile_gemm_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_reduce.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]")
|
||||
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_preshuffle.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_universal_preshuffle.cpp)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
|
||||
list(APPEND PROFILER_OPS profile_gemm_mx.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_mx.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_bwd_data.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_fwd.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_conv_bwd_data.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_conv_fwd.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp)
|
||||
|
||||
endif()
|
||||
|
||||
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR
|
||||
(SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)))
|
||||
list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
|
||||
endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
endif()
|
||||
|
||||
set(PROFILER_SOURCES profiler.cpp)
|
||||
@@ -129,103 +129,103 @@ endif()
|
||||
|
||||
|
||||
set(DEVICE_INSTANCES "")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_normalization_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_normalization_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_normalization_bwd_gamma_beta_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_softmax_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batchnorm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_pool2d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_pool3d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_avg_pool2d_bwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_avg_pool3d_bwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_max_pool_bwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_image_to_column_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_column_to_image_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_transpose_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_permute_scale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_normalization_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_normalization_bwd_data_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_normalization_bwd_gamma_beta_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_softmax_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_reduce_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batchnorm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_pool2d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_pool3d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_avg_pool2d_bwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_avg_pool3d_bwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_max_pool_bwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_image_to_column_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_column_to_image_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_transpose_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_permute_scale_instance)
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_contraction_scale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_contraction_scale_instance)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
|
||||
endif()
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
|
||||
endif()
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv1d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv3d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_conv1d_bwd_data_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_conv3d_bwd_data_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
|
||||
endif()
|
||||
|
||||
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR
|
||||
(SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)))
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance)
|
||||
endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
endif()
|
||||
|
||||
set(PROFILER_LIBS utility getopt::getopt)
|
||||
|
||||
Reference in New Issue
Block a user