Add other layouts for FP8 block scaled gemm (#2665)

* Start adding other layouts for gemm_ab_scale

* Add some instances

* Create tensor descriptors for A/B scales depending on A/B layout

* Fix formatting

* Revert some comments

* Revert commented instances in CMakeLists.txt

* Add some more instances for col-row gemm

* enable more row,row instances

* Use occupancy=1 for col,row layout to avoid spills
This commit is contained in:
Sami Remes
2025-08-18 11:46:10 +03:00
committed by GitHub
parent 7310830d14
commit 26d3300930
15 changed files with 758 additions and 13 deletions

View File

@@ -231,11 +231,22 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
}
};
constexpr index_t minimum_occupancy =
(BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave &&
MPerBlock * NPerBlock / BlockSize > 64)
? 1
: 2;
constexpr index_t minimum_occupancy = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout> &&
is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
// FIXME: many instances have many spills with occupancy > 1, a better solution
// needed to get best performance
return 1;
}
else
{
return (BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave &&
MPerBlock * NPerBlock / BlockSize > 64)
? 1
: 2;
}
}();
if(has_main_k_block_loop)
{

View File

@@ -402,6 +402,34 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
}
}
__host__ __device__ static constexpr auto MakeAScaleGridDesciptor_M_K(index_t M, index_t K)
{
const auto BM = math::integer_divide_ceil(M, ScaleBlockM);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(BK, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, BM));
}
}
__host__ __device__ static constexpr auto MakeBScaleGridDesciptor_N_K(index_t N, index_t K)
{
const auto BN = math::integer_divide_ceil(N, ScaleBlockN);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(BK, I1));
}
else if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, BN));
}
}
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
@@ -1181,14 +1209,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto a_scale_grid_desc_am_ak = MakeAScaleGridDesciptor_M_K(problem.M, problem.K);
const auto b_scale_grid_desc_bn_ak = MakeBScaleGridDesciptor_N_K(problem.N, problem.K);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(

View File

@@ -17,6 +17,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
// Row, Col
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
@@ -88,6 +89,152 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Row, Row
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Col, Row
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Col,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Col,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Col,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Col,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <typename A0DataType,
@@ -154,6 +301,32 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances(
op_ptrs);
}
}
#endif
return op_ptrs;

View File

@@ -2,15 +2,37 @@
set(GEMM_AB_SCALE_INSTANCES)
list(APPEND GEMM_AB_SCALE_INSTANCES
# Row, Col
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp
# Row, Row
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp
# Col, Row
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp
)
# Row, Col
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
# Row, Row
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
# Col, Row
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_gemm_ab_scale_instance ${GEMM_AB_SCALE_INSTANCES})

View File

@@ -0,0 +1,96 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F8 = f8_t;
using BF16 = bhalf_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
using PassThrough = element_wise::PassThrough;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
template <GemmSpecialization GemmSpec>
using device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_comp_instances =
std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Compute friendly
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 4, 4, 32, 32, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 4, 4, 16, 16, 4, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 2, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 4, 4, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 4, 4, 32, 32, 2, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 4, 4, 32, 32, 1, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 4, 4, 32, 32, 1, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
// clang-format on
>;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_mem_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData | BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Memory friendly
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 4, 4, 16, 16, 1, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 4, 4, 16, 16, 1, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 128, 4, 4, 16, 16, 1, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
// FIXME: KPerBlock=256 give numerically bad results:
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 256, 4, 4, 16, 16, 1, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 256, 4, 4, 16, 16, 1, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 4, 4, 32, 32, 1, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 4, 4, 32, 32, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 4, 4, 16, 16, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 4, 4, 32, 32, 1, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 4, 4, 32, 32, 1, 1, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 4, 4, 16, 16, 2, 1, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 2, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
// FIXME: KPerBlock=256 give numerically bad results:
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 256, 4, 4, 32, 32, 1, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 256, 4, 4, 16, 16, 2, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 4, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 4, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 4, 4, 32, 32, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 256, 4, 4, 32, 32, 2, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 256, 4, 4, 32, 32, 1, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,37 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Col,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_comp_instances<GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,37 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Col,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_comp_instances<GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Col,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_mem_instances<Intrawave,
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Col,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_mem_instances<Intrawave,
GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,87 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F8 = f8_t;
using BF16 = bhalf_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
using PassThrough = element_wise::PassThrough;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
template <GemmSpecialization GemmSpec>
using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_comp_instances =
std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Compute friendly
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 4, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 4, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
// clang-format on
>;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_mem_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData | BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Memory friendly
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 4, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 4, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 128, 8, 4, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<1, 0, 2>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 4, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<1, 0, 2>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 16, 4, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<1, 0, 2>, 1, 2, 2, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 4, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 4, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Row, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,37 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_comp_instances<GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,37 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_comp_instances<GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_mem_instances<Intrawave,
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_1_128_128_mem_instances<Intrawave,
GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -173,6 +173,40 @@ int profile_gemm_ab_scale(int argc, char* argv[])
Col{},
Row{});
}
else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN &&
scale_block_tile == ScaleBlockTile::Tile_1_128_128)
{
return profile(F8{},
F32{},
F8{},
F32{},
F8{},
F32{},
BF16{},
ck::Number<1>{},
ck::Number<128>{},
ck::Number<128>{},
Row{},
Row{},
Row{});
}
else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::KM_KN_MN &&
scale_block_tile == ScaleBlockTile::Tile_1_128_128)
{
return profile(F8{},
F32{},
F8{},
F32{},
F8{},
F32{},
BF16{},
ck::Number<1>{},
ck::Number<128>{},
ck::Number<128>{},
Col{},
Row{},
Row{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;