Wmma support for gemm_multiply_multiply_wp (#3278)

* Initial implementation with splitK support

* Add gfx11 support

* Fix compilation error

* Add instances

* Add irregular instances

* Fix GetBuffer arguments

* Minor changes

* Address review comments

* Fix compilation errors

* Fix copyright header
This commit is contained in:
Enrico Degregori
2025-12-03 16:38:23 +01:00
committed by GitHub
parent f29b67cf9b
commit 161835533b
30 changed files with 2482 additions and 86 deletions

View File

@@ -31,6 +31,7 @@ using TGemmMulMulF8F8F16Instances =
PassThrough,
MultiplyMultiply>>>;
#ifdef CK_USE_XDL
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1(
TGemmMulMulF8F8F16Instances& instances);
@@ -86,6 +87,21 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16
TGemmMulMulF8F8F16Instances& instances);
#endif
#ifdef CK_USE_WMMA
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p1(
TGemmMulMulF8F8F16Instances& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p2(
TGemmMulMulF8F8F16Instances& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p3(
TGemmMulMulF8F8F16Instances& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p4(
TGemmMulMulF8F8F16Instances& instances);
#endif
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
using TGemmMulMulF8F8BF16Instances =
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
@@ -100,6 +116,7 @@ using TGemmMulMulF8F8BF16Instances =
PassThrough,
MultiplyMultiply>>>;
#ifdef CK_USE_XDL
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p1(
TGemmMulMulF8F8BF16Instances& instances);
@@ -153,7 +170,21 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma1
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p6(
TGemmMulMulF8F8BF16Instances& instances);
#endif
#ifdef CK_USE_WMMA
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p1(
TGemmMulMulF8F8BF16Instances& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p2(
TGemmMulMulF8F8BF16Instances& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p3(
TGemmMulMulF8F8BF16Instances& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p4(
TGemmMulMulF8F8BF16Instances& instances);
#endif
#endif
template <typename ADataType,
@@ -200,6 +231,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
#ifdef CK_USE_XDL
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p1(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p2(
@@ -237,6 +269,17 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances_v2(
op_ptrs);
#endif
#ifdef CK_USE_WMMA
add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p1(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p2(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p3(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p4(
op_ptrs);
#endif
}
}
#endif
@@ -248,6 +291,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
#ifdef CK_USE_XDL
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p1(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p2(
@@ -285,6 +329,17 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances_v2(
op_ptrs);
#endif
#ifdef CK_USE_WMMA
add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p1(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p2(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p3(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p4(
op_ptrs);
#endif
}
}
#endif

View File

@@ -1,7 +1,7 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
set(GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES)
list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES
@@ -42,6 +42,15 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p4.cpp
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p5.cpp
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p6.cpp
device_gemm_multiply_multiply_wp_wmma_f8_f8_bf16_mk_wmma_mn_default_instance_p1.cpp
device_gemm_multiply_multiply_wp_wmma_f8_f8_bf16_mk_wmma_mn_default_instance_p2.cpp
device_gemm_multiply_multiply_wp_wmma_f8_f8_bf16_mk_wmma_mn_default_instance_p3.cpp
device_gemm_multiply_multiply_wp_wmma_f8_f8_bf16_mk_wmma_mn_default_instance_p4.cpp
device_gemm_multiply_multiply_wp_wmma_f8_f8_f16_mk_wmma_mn_default_instance_p1.cpp
device_gemm_multiply_multiply_wp_wmma_f8_f8_f16_mk_wmma_mn_default_instance_p2.cpp
device_gemm_multiply_multiply_wp_wmma_f8_f8_f16_mk_wmma_mn_default_instance_p3.cpp
device_gemm_multiply_multiply_wp_wmma_f8_f8_f16_mk_wmma_mn_default_instance_p4.cpp
)
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")

View File

@@ -0,0 +1,105 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_b_preshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F8 = f8_t;
using BF16 = bhalf_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
using MultiplyMultiply = element_wise::MultiplyMultiply;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto v1 = BlockGemmPipelineVersion::v1;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p1 =
std::tuple<
// clang-format off
//###########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| CDEBlockTransfer| BlkGemmPipeSched| Block-wiseGemm|
//###########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| | Pipeline|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| | Verision|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 8, 1, 32>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 128, 16, 16, 16, 16, 4, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 8, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>
// clang-format on
>;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p2 =
std::tuple<
// clang-format off
//###########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| CDEBlockTransfer| BlkGemmPipeSched| Block-wiseGemm|
//###########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| | Pipeline|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| | Verision|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 8, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 16, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 256, 64, 128, 16, 16, 16, 16, 16, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 128, 16, 16, 16, 16, 8, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>
// clang-format on
>;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p3 =
std::tuple<
// clang-format off
//###########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| CDEBlockTransfer| BlkGemmPipeSched| Block-wiseGemm|
//###########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| | Pipeline|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| | Verision|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 256, 16, 16, 16, 16, 4, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 8, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 256, 16, 16, 16, 16, 8, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>
// clang-format on
>;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p4 =
std::tuple<
// clang-format off
//###########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| CDEBlockTransfer| BlkGemmPipeSched| Block-wiseGemm|
//###########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| | Pipeline|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| | Verision|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 4, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 4, 7, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 2, 7, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 224, 128, 16, 16, 16, 16, 1, 7, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,33 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "device_gemm_multiply_multiply_wp_wmma_f8_f8_bf16_mk_wmma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p1<
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,33 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "device_gemm_multiply_multiply_wp_wmma_f8_f8_bf16_mk_wmma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p2<
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,33 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "device_gemm_multiply_multiply_wp_wmma_f8_f8_bf16_mk_wmma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p3(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p3<
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,33 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "device_gemm_multiply_multiply_wp_wmma_f8_f8_bf16_mk_wmma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p4(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p4<
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,105 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_b_preshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F8 = f8_t;
using F16 = half_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
using MultiplyMultiply = element_wise::MultiplyMultiply;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto v1 = BlockGemmPipelineVersion::v1;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p1 =
std::tuple<
// clang-format off
//###########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| CDEBlockTransfer| BlkGemmPipeSched| Block-wiseGemm|
//###########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| | Pipeline|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| | Verision|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 8, 1, 32>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 128, 16, 16, 16, 16, 4, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 8, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>
// clang-format on
>;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p2 =
std::tuple<
// clang-format off
//###########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| CDEBlockTransfer| BlkGemmPipeSched| Block-wiseGemm|
//###########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| | Pipeline|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| | Verision|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 8, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 16, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 256, 64, 128, 16, 16, 16, 16, 16, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 128, 16, 16, 16, 16, 8, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>
// clang-format on
>;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p3 =
std::tuple<
// clang-format off
//###########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| CDEBlockTransfer| BlkGemmPipeSched| Block-wiseGemm|
//###########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| | Pipeline|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| | Verision|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 256, 16, 16, 16, 16, 4, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 8, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 256, 16, 16, 16, 16, 8, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>
// clang-format on
>;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p4 =
std::tuple<
// clang-format off
//###########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| CDEBlockTransfer| BlkGemmPipeSched| Block-wiseGemm|
//###########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| | Pipeline|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| | Verision|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 4, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 4, 7, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 2, 7, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>,
DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 224, 128, 16, 16, 16, 16, 1, 7, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, v1, F8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,33 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "device_gemm_multiply_multiply_wp_wmma_f8_f8_f16_mk_wmma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p1<
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,33 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "device_gemm_multiply_multiply_wp_wmma_f8_f8_f16_mk_wmma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p2<
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,33 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "device_gemm_multiply_multiply_wp_wmma_f8_f8_f16_mk_wmma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p3(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p3<
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,33 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "device_gemm_multiply_multiply_wp_wmma_f8_f8_f16_mk_wmma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p4(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p4<
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck