mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
Wmma support for gemm_ab_scale (#3314)
* Support gemm_ab_scale: - Add tests - Integrate scaling implementation in multiple D - Generalize existing b_scale for ab_scale - Add instances - Generalize implementation for ScaleBlockM, ScaleBlockN, ScaleBlockK - Add support for all layouts supported by xdl - Fix splitk xdl * Fix copyright * Wmma support for gemm_blockscale_wp (#3315) * Support for preshuffle with ab scale - add support for b preshuffle in GridwiseGemm_wmma_cshuffle_v3_ab_scale - add support for AScaleLayout and BScaleLayout (can be different from ALayout and BLayout, respectively) - add Run method in v1 pipeline to support preshuffle + scaling - add support for preshuffle gemms in common invoker - Add splitk support * Fix copyright header
This commit is contained in:
@@ -16,7 +16,231 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
|
||||
#ifdef CK_USE_WMMA_FP8
|
||||
// Row, Col
|
||||
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// Row, Row
|
||||
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Row,
|
||||
Row,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Row,
|
||||
Row,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Row,
|
||||
Row,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Row,
|
||||
Row,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// Col, Row
|
||||
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Col,
|
||||
Row,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Col,
|
||||
Row,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Col,
|
||||
Row,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Col,
|
||||
Row,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_USE_XDL
|
||||
// Row, Col
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
@@ -236,6 +460,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename A0DataType,
|
||||
typename A1DataType,
|
||||
@@ -245,23 +470,124 @@ template <typename A0DataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD_ABScale<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<>,
|
||||
CLayout,
|
||||
A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmMultipleD_ABScaleSplitK<ALayout,
|
||||
BLayout,
|
||||
Tuple<>,
|
||||
CLayout,
|
||||
A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleD_ABScaleSplitK<ALayout,
|
||||
BLayout,
|
||||
Tuple<>,
|
||||
CLayout,
|
||||
A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
|
||||
#ifdef CK_USE_XDL
|
||||
// No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA_FP8
|
||||
if constexpr(is_same_v<A0DataType, f8_t> && is_same_v<B0DataType, f8_t> &&
|
||||
is_same_v<CDataType, bhalf_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename A0DataType,
|
||||
typename A1DataType,
|
||||
typename B0DataType,
|
||||
typename B1DataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmMultipleD_ABScale<ALayout,
|
||||
BLayout,
|
||||
Tuple<>,
|
||||
CLayout,
|
||||
A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleD_ABScale<ALayout,
|
||||
BLayout,
|
||||
@@ -276,15 +602,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
|
||||
#ifdef CK_USE_XDL
|
||||
if constexpr(is_same_v<A0DataType, f8_t> && is_same_v<B0DataType, f8_t> &&
|
||||
is_same_v<CDataType, bhalf_t>)
|
||||
{
|
||||
@@ -328,6 +655,33 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_WMMA_FP8
|
||||
using Wrapper = DeviceGemmMultipleD_ABScaleSplitKWrapper<ALayout,
|
||||
BLayout,
|
||||
Tuple<>,
|
||||
CLayout,
|
||||
A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
|
||||
auto new_op_ptrs =
|
||||
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
|
||||
for(auto& op_ptr : new_op_ptrs)
|
||||
{
|
||||
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
|
||||
}
|
||||
#endif // CK_USE_WMMA_FP8
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
@@ -17,6 +17,47 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
|
||||
#if(defined(CK_USE_WMMA) && defined(CK_USE_WMMA_FP8))
|
||||
void add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>&
|
||||
instances);
|
||||
#endif // CK_USE_WMMA && CK_USE_WMMA_FP8
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
void add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_BlockScale_BPreshuffle<Row,
|
||||
Col,
|
||||
@@ -93,6 +134,82 @@ void add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpad
|
||||
PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename A0DataType,
|
||||
typename A1DataType,
|
||||
typename B0DataType,
|
||||
typename B1DataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<>,
|
||||
CLayout,
|
||||
A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<>,
|
||||
CLayout,
|
||||
A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
// No XDL instances for DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK at the moment
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if(defined(CK_USE_WMMA) && defined(CK_USE_WMMA_FP8))
|
||||
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
|
||||
if constexpr(is_same_v<A0DataType, f8_t> && is_same_v<B0DataType, f8_t> &&
|
||||
is_same_v<CDataType, bhalf_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_default_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // CK_USE_WMMA && CK_USE_WMMA_FP8
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename A0DataType,
|
||||
typename A1DataType,
|
||||
@@ -143,6 +260,7 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
|
||||
if constexpr(is_same_v<A0DataType, f8_t> && is_same_v<B0DataType, f8_t> &&
|
||||
is_same_v<CDataType, bhalf_t>)
|
||||
@@ -162,6 +280,35 @@ struct DeviceOperationInstanceFactory<
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
// Reuse DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK instances
|
||||
using Wrapper = DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<>,
|
||||
CLayout,
|
||||
A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
auto new_op_ptrs =
|
||||
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
|
||||
for(auto& op_ptr : new_op_ptrs)
|
||||
{
|
||||
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user