Wmma support for gemm_ab_scale (#3314)

* Support gemm_ab_scale:

 - Add tests
 - Integrate scaling implementation in multiple D
 - Generalize existing b_scale for ab_scale
 - Add instances
 - Generalize implementation for ScaleBlockM, ScaleBlockN, ScaleBlockK
 - Add support for all layouts supported by xdl
 - Fix splitk xdl

* Fix copyright

* Wmma support for gemm_blockscale_wp (#3315)

* Support for preshuffle with ab scale

 - add support for b preshuffle in GridwiseGemm_wmma_cshuffle_v3_ab_scale
 - add support for AScaleLayout and BScaleLayout (can be different
   from ALayout and BLayout, respectively)
 - add Run method in v1 pipeline to support preshuffle + scaling
 - add support for preshuffle gemms in common invoker
 - Add splitk support

* Fix copyright header
This commit is contained in:
Enrico Degregori
2025-12-11 09:06:20 +01:00
committed by GitHub
parent d66e5f667c
commit ce99cab605
51 changed files with 5144 additions and 552 deletions

View File

@@ -16,7 +16,231 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
#ifdef CK_USE_WMMA_FP8
// Row, Col
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Row, Row
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Row,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Row,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Row,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Row,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Col, Row
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Col,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Col,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Col,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<Col,
Row,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_USE_XDL
// Row, Col
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
@@ -236,6 +460,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_
PassThrough,
PassThrough>>>& instances);
#endif
#endif
template <typename A0DataType,
typename A1DataType,
@@ -245,23 +470,124 @@ template <typename A0DataType,
typename ALayout,
typename BLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD_ABScale<
ALayout,
BLayout,
Tuple<>,
CLayout,
A0DataType,
A1DataType,
B0DataType,
B1DataType,
Tuple<>,
CDataType,
1,
128,
128,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
// Factory specialization for DeviceGemmMultipleD_ABScaleSplitK with no extra D
// tensors (Tuple<>) and PassThrough element-wise ops. The 1/128/128 integers
// presumably correspond to ScaleBlockM/ScaleBlockN/ScaleBlockK (matches the
// _1_128_128 suffix of the instance names) -- confirm against the instance
// definitions.
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleD_ABScaleSplitK<ALayout,
BLayout,
Tuple<>,
CLayout,
A0DataType,
A1DataType,
B0DataType,
B1DataType,
Tuple<>,
CDataType,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>
{
// Concrete device-op interface type whose instances this factory returns.
using DeviceOp = DeviceGemmMultipleD_ABScaleSplitK<ALayout,
BLayout,
Tuple<>,
CLayout,
A0DataType,
A1DataType,
B0DataType,
B1DataType,
Tuple<>,
CDataType,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>;
// Returns every registered instance whose compile-time data types and
// layouts match this specialization's template arguments; the vector is
// empty when no combination matches or the required build flags are off.
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
#ifdef CK_USE_XDL
// No XDL instances for DeviceGemmMultipleD_ABScaleSplitK at the moment
#endif
#ifdef CK_USE_WMMA_FP8
// WMMA instances are provided only for f8 x f8 -> bf16.
if constexpr(is_same_v<A0DataType, f8_t> && is_same_v<B0DataType, f8_t> &&
is_same_v<CDataType, bhalf_t>)
{
// Row-major A, column-major B ("mk_nk_mn").
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
op_ptrs);
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
op_ptrs);
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
op_ptrs);
}
// Row-major A, row-major B ("mk_kn_mn").
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances(
op_ptrs);
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances(
op_ptrs);
add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances(
op_ptrs);
}
// Column-major A, row-major B ("km_kn_mn").
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances(
op_ptrs);
add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances(
op_ptrs);
add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances(
op_ptrs);
}
}
#endif
#endif
return op_ptrs;
}
};
template <typename A0DataType,
typename A1DataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleD_ABScale<ALayout,
BLayout,
Tuple<>,
CLayout,
A0DataType,
A1DataType,
B0DataType,
B1DataType,
Tuple<>,
CDataType,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>
{
using DeviceOp = DeviceGemmMultipleD_ABScale<ALayout,
BLayout,
@@ -276,15 +602,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
1,
128,
128,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
PassThrough,
PassThrough,
PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
#ifdef CK_USE_XDL
if constexpr(is_same_v<A0DataType, f8_t> && is_same_v<B0DataType, f8_t> &&
is_same_v<CDataType, bhalf_t>)
{
@@ -328,6 +655,33 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
op_ptrs);
}
}
#endif
#ifdef CK_USE_WMMA_FP8
using Wrapper = DeviceGemmMultipleD_ABScaleSplitKWrapper<ALayout,
BLayout,
Tuple<>,
CLayout,
A0DataType,
A1DataType,
B0DataType,
B1DataType,
Tuple<>,
CDataType,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>;
auto new_op_ptrs =
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
for(auto& op_ptr : new_op_ptrs)
{
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
}
#endif // CK_USE_WMMA_FP8
#endif
return op_ptrs;
}

View File

@@ -17,6 +17,47 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
#if(defined(CK_USE_WMMA) && defined(CK_USE_WMMA_FP8))
void add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>&
instances);
void add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
1,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>&
instances);
#endif // CK_USE_WMMA && CK_USE_WMMA_FP8
#ifdef CK_USE_XDL
void add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_BlockScale_BPreshuffle<Row,
Col,
@@ -93,6 +134,82 @@ void add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpad
PassThrough>>>&
instances);
#endif
#endif
// Factory specialization for the split-K block-scale GEMM with preshuffled B
// (DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK), no extra D tensors
// (Tuple<>) and PassThrough element-wise ops. The 1/128/128 integers
// presumably correspond to ScaleBlockM/ScaleBlockN/ScaleBlockK (matches the
// _1_128_128 suffix of the instance names) -- confirm against the instance
// definitions.
template <typename A0DataType,
typename A1DataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<
ALayout,
BLayout,
Tuple<>,
CLayout,
A0DataType,
A1DataType,
B0DataType,
B1DataType,
Tuple<>,
CDataType,
1,
128,
128,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
// Concrete device-op interface type whose instances this factory returns.
using DeviceOp = DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<
ALayout,
BLayout,
Tuple<>,
CLayout,
A0DataType,
A1DataType,
B0DataType,
B1DataType,
Tuple<>,
CDataType,
1,
128,
128,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
// Returns every registered instance whose compile-time data types and
// layouts match this specialization's template arguments; the vector is
// empty when no combination matches or the required build flags are off.
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
// No XDL instances for DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK at the moment
#endif // CK_USE_XDL
#if(defined(CK_USE_WMMA) && defined(CK_USE_WMMA_FP8))
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
// WMMA instances are provided only for f8 x f8 -> bf16 with row-major A,
// column-major B ("mk_nk_mn") and row-major C.
if constexpr(is_same_v<A0DataType, f8_t> && is_same_v<B0DataType, f8_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
op_ptrs);
add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_default_instances(
op_ptrs);
}
}
#endif
#endif // CK_USE_WMMA && CK_USE_WMMA_FP8
return op_ptrs;
}
};
template <typename A0DataType,
typename A1DataType,
@@ -143,6 +260,7 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_USE_XDL
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
if constexpr(is_same_v<A0DataType, f8_t> && is_same_v<B0DataType, f8_t> &&
is_same_v<CDataType, bhalf_t>)
@@ -162,6 +280,35 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
#endif
#if defined(CK_USE_WMMA)
// Reuse DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK instances
using Wrapper = DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper<
ALayout,
BLayout,
Tuple<>,
CLayout,
A0DataType,
A1DataType,
B0DataType,
B1DataType,
Tuple<>,
CDataType,
1,
128,
128,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
auto new_op_ptrs =
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
for(auto& op_ptr : new_op_ptrs)
{
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
}
#endif // CK_USE_WMMA
return op_ptrs;
}
};