From 7c0e29cc0f6f60ab66b48e324b2481d167722dd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Thu, 15 May 2025 16:21:34 +0200 Subject: [PATCH 001/147] Extend 64x64 with 4 waves instances for grouped conv bwd wei (#2187) * Extend 64x64 with 4 waves instnaces for grouped conv bwd wei * Fix * fix * fix --- ...conv_bwd_weight_two_stage_xdl_instance.hpp | 29 ++++++++++++++++--- ...e_grouped_conv_bwd_weight_xdl_instance.hpp | 7 ++++- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp index 1c4dc8a445..0ed12b984b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp @@ -72,7 +72,14 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8> + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1> + // clang-format on >; @@ -138,7 +145,13 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instance DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8> + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 4, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 4, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1> // clang-format on >; @@ -218,7 +231,11 @@ using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8 ,1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 1>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 1> + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 1>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1, F16, F16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1, F16, F16, 2, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1> // clang-format on >; @@ -275,7 +292,11 @@ using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instance DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 8 ,1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 4, 1>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 8, 1> + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 8, 1>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1, BF16, BF16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1, BF16, BF16, 2, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1, BF16, BF16, 1, 1> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp index a493719637..3587570e42 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp @@ -87,7 +87,12 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances = std::tuple< DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 4, 4, true, S<1, 8, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 16>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 4, 4, true, S<1, 8, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 1, 4, true, 1, 1, S<1, 4, 1, 64>, 1, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 1, 4, true, S<1, 8, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 16>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 1, 4, true, S<1, 8, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 1, 4, true, 1, 1, S<1, 4, 1, 64>, 1, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> // clang-format on >; From 3d8d6e75e485f5811df0ca37272f119392727726 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Thu, 15 May 2025 10:28:31 -0700 Subject: [PATCH 002/147] Adding validation for tile sizes in Tile Engine (#2189) * Adding validation for tile sizes * Add architecture in config, and shuffle lines of code in warp_gemm.hpp * Enable MFMA for gfx950, and invalid tile handling --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 26 ++--- .../warp/warp_gemm_attribute_mfma_impl.hpp | 8 +- .../gemm/configs/instance_combination.json | 4 +- tile_engine/ops/gemm/gemm_instance_builder.py | 96 +++++++++++++++---- 4 files changed, 96 insertions(+), 38 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index f050a8e382..be5d5690ff 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -204,14 +204,6 @@ using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl>>; -using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl, - 2>>; - -using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl, - 2>>; - using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl< WarpGemmAtrributeMfma>>; @@ -221,20 +213,28 @@ using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl< using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl< WarpGemmAtrributeMfma>>; -using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl, +using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl, + 2>>; + +using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl, 2>>; using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl< WarpGemmAtrributeMfma>>; +using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl, + 2>>; + using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl, 2>>; -using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; - using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 69d22496f1..4bc4884beb 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1092,7 +1092,7 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base } else { -#if defined(__gfx94__) +#if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); @@ -1116,7 +1116,7 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx94__) +#if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); @@ -1251,7 +1251,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base } else { -#if defined(__gfx94__) +#if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); @@ -1286,7 +1286,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx94__) +#if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); diff --git a/tile_engine/ops/gemm/configs/instance_combination.json b/tile_engine/ops/gemm/configs/instance_combination.json index 53197ada6c..b497513efa 100644 --- a/tile_engine/ops/gemm/configs/instance_combination.json +++ b/tile_engine/ops/gemm/configs/instance_combination.json @@ -1,5 +1,7 @@ { - + "architecture": { + "values": ["gfx90a"] + }, "layout_a": { "values": ["r"] }, diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 3839523e3d..dd8b4d1157 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -23,7 +23,39 @@ DATA_TYPE_MAP = {'fp32' : 'float', } LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor', - 'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'} + 'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'} + + +warp_tile_combinations_map = { + "gfx90a": { + 'fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'fp8': [[32, 32, 16], [32, 32, 32]], + 'bf8': [[32, 32, 16], [32, 32, 32]] + }, + "gfx942": { + 'fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'fp8': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + 'bf8': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]] + }, + "gfx950": { + 'fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'fp8': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], + 'bf8': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]] + } + } + +def sizeOf(data_type): + if data_type == 'fp16' or data_type == 'bf16': + return 2 + elif data_type == 'int8' or data_type == 'fp8' or data_type == 'bf8': + return 1 + elif data_type == 'int4': ## TODO:: needs to confirm + return 0.5 + else: + return 4 DEFAULT_EPILOGUE = """ using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< @@ -168,11 +200,15 @@ class GemmConfig: self.matrix_cfg : Dict[str, Any] = {} self.impl_cfg : Dict[str, Any] = {} for key, value in config_data.items(): - if key in ["datatype", "layout_a", "layout_b", "layout_c"]: + if key in ["architecture", "datatype", "layout_a", "layout_b", "layout_c"]: self.matrix_cfg[key] = value else: self.impl_cfg[key] = value + @property + def architecture(self) -> str: + return self.matrix_cfg["architecture"]["values"][0] + @property def datatype(self) -> str: return self.matrix_cfg["datatype"]["values"][0] @@ -201,7 +237,7 @@ class GemmCodeGenerator: def _validate_config(self): """Validate matrix and implementation configurations""" # Matrix config validation - for param in ["datatype", "layout_a", "layout_b", "layout_c"]: + for param in ["architecture", "datatype", "layout_a", "layout_b", "layout_c"]: if len(self.config.matrix_cfg[param]["values"]) != 1: raise ValueError(f"Matrix config {param} must have exactly one value") @@ -327,7 +363,7 @@ namespace {group_name} {{ return f""" template void try_run(ck_tile::TailNumber tn) {{ - if constexpr (Pipeline::PrefetchStages > static_cast(TN)) {{ + if constexpr (Pipeline::PrefetchStages > static_cast(TN) - 1) {{ if (tn == TN) {{ RunSplitk(ck_tile::bool_constant{{}}, ck_tile::integral_constant{{}}); @@ -477,6 +513,30 @@ struct GemmKernel {{ content += f"#include \"gemm_{group}.hpp\"\n" (self.output_dir / "gemm_instances.hpp").write_text(content) + def is_tile_valid(self, tile: tuple, group: str) -> bool: + """Check if the tile configuration is valid for the given group""" + # Extract tile parameters + tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile + + # Extract the pipeline and epilogue from the group name + _, pipeline, epilogue, scheduler, *_ = group.split("_") + + if tile_m % (warp_m * warp_tile_m) == 0 and \ + tile_n % (warp_n * warp_tile_n) == 0 and \ + tile_k % (warp_k * warp_tile_k) == 0: + total_tile_in_lds = (tile_m * tile_k + tile_n * tile_k ) * sizeOf(self.config.datatype) + # Validate and append valid tile parameters + is_compv4 = pipeline == "compv4" + max_tile_size = pow(2, 16) if is_compv4 else pow(2, 15) + + if total_tile_in_lds > max_tile_size: + raise ValueError(f'Total tile size should not exceed {max_tile_size / 1024}KB of LDS. ' + f'{tile_m} * {tile_n} * {tile_k} > {max_tile_size / 1024}KB') + arch = self.config.architecture + if [warp_tile_m, warp_tile_n, warp_tile_k] in warp_tile_combinations_map[arch][self.config.datatype]: + return True + return False + def _generate_dispatcher(self): """Generate dispatch mechanism""" content = """// SPDX-License-Identifier: MIT @@ -517,7 +577,7 @@ struct GemmDispatcher { self.config.impl_cfg["warp_tile_k"]["values"] )) - + for group in self.all_kernels: content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_host_result, @@ -526,26 +586,22 @@ struct GemmDispatcher { const ck_tile::stream_config& stream) {{ if(structured_sparsity){{ // SMFMA""" for tile in tile_params: - # Check if we have valid tile/warp combinations - # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m - if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \ - ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): - continue - sparse = self.atype == 'fp16' and \ - ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or - (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) - content += f""" + if self.is_tile_valid(tile, group): + sparse = self.atype == 'fp16' and \ + ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or + (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) + content += f""" run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" + else: + raise ValueError(f"Invalid tile configuration for group {group}: {tile}") content += f""" }} else {{""" for tile in tile_params: - # Check if we have valid tile/warp combinations - # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m - if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \ - ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): - continue - content += f""" + if self.is_tile_valid(tile, group): + content += f""" run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" + else: + raise ValueError(f"Invalid tile configuration for group {group}: {tile}") content += f""" }} }};\n""" From 8cb0474b3d880abe55bca977856a4be104aac337 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 16 May 2025 02:47:29 +0800 Subject: [PATCH 003/147] Use only qr_async pipeline for batch_prefill (#2195) --- .../ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 30b9299963..76b9429b2e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -470,11 +470,10 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) else: if bias == "bias": - # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) else: pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) From 791802b381c99e47966cbf4a987b91ab3d56bcfc Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 16 May 2025 15:14:46 +0800 Subject: [PATCH 004/147] [CK_TILE] fMHA batch_prefill block index & logits soft-capping optimizations (#2198) * Write soft-sign in inline asm * Change tile idx computation * Add macro to turn off soft-sign asm opt * Use simple for loop to avoid register spill * Only do block id transform for masking cases --- include/ck_tile/ops/fmha/block/variants.hpp | 38 ++++++++++++++++--- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 21 ++++++++-- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 13 ++++++- 3 files changed, 63 insertions(+), 9 deletions(-) diff --git a/include/ck_tile/ops/fmha/block/variants.hpp b/include/ck_tile/ops/fmha/block/variants.hpp index 90fc5656fc..d8b0cdbb86 100644 --- a/include/ck_tile/ops/fmha/block/variants.hpp +++ b/include/ck_tile/ops/fmha/block/variants.hpp @@ -15,7 +15,36 @@ #define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH #endif +#ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM +#define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0 +#endif + namespace ck_tile { +namespace internal { +__device__ inline float +exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp) +{ +#if(defined(__gfx90a__) || defined(__gfx94__)) && \ + (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \ + CK_TILE_ATTENTION_USE_SOFTSIGN_ASM) + /// NOTICE: Make sure softmax_scale is stored in SGPR + float result, numerator, denominator; + asm volatile( + "v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n" + "v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n" + "v_rcp_f32_e32 %[denominator], %[denominator]\n" + "v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n" + "v_mul_f32_e32 %[result], %[numerator], %[denominator]" + : [numerator] "=&v"(numerator), [denominator] "=&v"(denominator), [result] "=v"(result) + : [softmax_scale] "s"(softmax_scale), + [logits] "v"(logits), + [logits_soft_cap_rcp] "v"(logits_soft_cap_rcp)); + return result; +#else + return softmax_scale * logits * rcp(1.f + abs(logits * logits_soft_cap_rcp)); +#endif +} +} // namespace internal template struct StandardAttentionParams @@ -169,8 +198,8 @@ struct LogitsSoftCap return params.logits_soft_cap * tanh_fast(type_convert(logits) * params.logits_soft_cap_rcp); #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN - return params.sm_scale * type_convert(logits) * - rcp(1.f + abs(type_convert(logits) * params.logits_soft_cap_rcp)); + return internal::exp2_soft_sign_impl( + params.sm_scale, type_convert(logits), params.logits_soft_cap_rcp); #endif } else @@ -239,9 +268,8 @@ struct ComposedAttention return params.logits_soft_cap * tanh_fast(type_convert(logits) * params.logits_soft_cap_rcp); #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN - return params.sm_scale * type_convert(logits) * - rcp(1.f + - abs(type_convert(logits) * params.logits_soft_cap_rcp)); + return internal::exp2_soft_sign_impl( + params.sm_scale, type_convert(logits), params.logits_soft_cap_rcp); #endif } else diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index ba327ee511..7472c82114 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -651,8 +651,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel }; const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } } else { @@ -672,7 +679,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } } } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index e07cf1c94e..8691622bb0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -6,8 +6,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -498,6 +499,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync #else for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) { +#if(defined(__gfx90a__) || defined(__gfx94__)) && \ + (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \ + CK_TILE_ATTENTION_USE_SOFTSIGN_ASM) + // Avoid data hazard if v_mfma is followed by inline asm consumer + // instructions. In this case, compiler won't add s_nop for us + if(i == s_acc.thread_buf_.size() / 2) + { + __builtin_amdgcn_sched_barrier(0); + } +#endif apply_logits_transform(s_acc.thread_buf_[i]); } #endif From fa3c6811d8e81096f52779bf0877777bf405d241 Mon Sep 17 00:00:00 2001 From: Mateusz Ozga <110818320+mozga-amd@users.noreply.github.com> Date: Fri, 16 May 2025 10:18:47 +0200 Subject: [PATCH 005/147] Disable conv for Filter1x1Stride1Pad0 when K or C is even (#2186) --- include/ck/ck.hpp | 3 +++ .../device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 7 +++++++ .../test_grouped_convnd_bwd_weight.cpp | 1 + 3 files changed, 11 insertions(+) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index e38f166c1a..26e4787949 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -222,6 +222,9 @@ // TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread" #define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 +// workaround: conv crash when K, C is even +#define CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN 1 + // workaround: compiler crash when compiling recursive lambda #define CK_WORKAROUND_SWDEV_275126 1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index dd5b97096d..869457a99e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -1206,6 +1206,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { +// workaround: disable when K, C is even +#if CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN + if(arg.Conv_C_ % 2 == 0 || arg.Conv_K_ % 2 == 0) + { + return false; + } +#endif // check if it's 1x1, stride=1 pad = 0 conv for(int i = 0; i < NDimSpatial; i++) { diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index 21f2cb5ce6..95a0a09414 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -188,6 +188,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D) TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D) { this->conv_params.clear(); + this->conv_params.push_back({2, 2, 64, 4, 4, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( {2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back({2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); From 40668c9a993ca9391eb628bbb4be3ca3fb4e7e56 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 16 May 2025 07:40:53 -0700 Subject: [PATCH 006/147] Build and store CK library deb package for all targets daily. (#2196) * generate and store library package for all targets * use ninja to build packages for all targets * make sure to use ftime-trace when using ninja * make sure build trace only runs on gfx9 * archive lib package and stash only library package --- Jenkinsfile | 135 +++++++++--------- .../gpu/CMakeLists.txt | 2 +- 2 files changed, 67 insertions(+), 70 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 68e0fa1246..c26350f120 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -93,6 +93,30 @@ def build_compiler(){ return compiler } +def check_arch(){ + def arch_type = 0 + sh 'rocminfo | tee rocminfo.log' + if ( runShell('grep -n "gfx90a" rocminfo.log') ){ + arch_type = 1 + } + else if ( runShell('grep -n "gfx942" rocminfo.log') ) { + arch_type = 2 + } + else if ( runShell('grep -n "gfx10" rocminfo.log') ) { + arch_type = 3 + } + else if ( runShell('grep -n "gfx11" rocminfo.log') ) { + arch_type = 4 + } + else if ( runShell('grep -n "gfx12" rocminfo.log') ) { + arch_type = 5 + } + else if ( runShell('grep -n "gfx908" rocminfo.log') ) { + arch_type = 6 + } + return arch_type +} + def getDockerImage(Map conf=[:]){ env.DOCKER_BUILDKIT=1 def prefixpath = conf.get("prefixpath", "/opt/rocm") @@ -287,7 +311,7 @@ def cmake_build(Map conf=[:]){ def build_cmd def execute_cmd = conf.get("execute_cmd", "") if(!setup_args.contains("NO_CK_BUILD")){ - if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){ + if (setup_args.contains("gfx9") && params.NINJA_BUILD_TRACE){ echo "running ninja build trace" setup_cmd = conf.get("setup_cmd", """${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" -O3 -ftime-trace " .. """) build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}") @@ -315,7 +339,7 @@ def cmake_build(Map conf=[:]){ sh cmd //run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){ - if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){ + if ((setup_args.contains("gfx9") && params.NINJA_BUILD_TRACE) || params.BUILD_INSTANCES_ONLY){ sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json" sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log" sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis.log" @@ -323,7 +347,15 @@ def cmake_build(Map conf=[:]){ archiveArtifacts "clang_build_analysis.log" // do not run unit tests when building instances only if(!params.BUILD_INSTANCES_ONLY){ - sh "ninja test" + sh "ninja check" + } + if(params.BUILD_INSTANCES_ONLY){ + // build deb packages + echo "Build packages" + sh 'ninja -j64 package' + archiveArtifacts artifacts: 'composablekernel-dev*.deb' + sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.1.0_amd64.deb' + stash includes: "composablekernel-dev_all_targets_1.1.0_amd64.deb", name: "packages" } } else{ @@ -340,21 +372,14 @@ def cmake_build(Map conf=[:]){ archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true } //check the node gpu architecture - def arch_type = 0 - sh 'rocminfo | tee rocminfo.log' - if ( runShell('grep -n "gfx90a" rocminfo.log') ){ - arch_type = 1 - } - else if ( runShell('grep -n "gfx942" rocminfo.log') ) { - arch_type = 2 - } + def arch = check_arch() if (params.RUN_CK_TILE_FMHA_TESTS){ try{ archiveArtifacts "perf_fmha_*.log" - if (arch_type == 1){ + if (arch == 1){ stash includes: "perf_fmha_**_gfx90a.log", name: "perf_fmha_log_gfx90a" } - else if (arch_type == 2){ + else if (arch == 2){ stash includes: "perf_fmha_**_gfx942.log", name: "perf_fmha_log_gfx942" } } @@ -379,10 +404,10 @@ def cmake_build(Map conf=[:]){ if (params.RUN_CK_TILE_GEMM_TESTS){ try{ archiveArtifacts "perf_tile_gemm_**.log" - if (arch_type == 1){ + if (arch == 1){ stash includes: "perf_tile_gemm_**_gfx90a.log", name: "perf_tile_gemm_log_gfx90a" } - else if (arch_type == 2){ + else if (arch == 2){ stash includes: "perf_tile_gemm_**_gfx942.log", name: "perf_tile_gemm_log_gfx942" } } @@ -410,7 +435,13 @@ def buildHipClangJob(Map conf=[:]){ def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts + if ( params.BUILD_INSTANCES_ONLY ){ + dockerOpts = "--group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + } + else{ + dockerOpts = "--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + } if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } @@ -521,28 +552,9 @@ def Build_CK(Map conf=[:]){ timeout(time: 20, unit: 'HOURS') { //check whether to run performance tests on this node - def arch_type = 0 - sh 'rocminfo | tee rocminfo.log' - if ( runShell('grep -n "gfx90a" rocminfo.log') ){ - arch_type = 1 - } - else if ( runShell('grep -n "gfx942" rocminfo.log') ) { - arch_type = 2 - } - else if ( runShell('grep -n "gfx10" rocminfo.log') ) { - arch_type = 3 - } - else if ( runShell('grep -n "gfx11" rocminfo.log') ) { - arch_type = 4 - } - else if ( runShell('grep -n "gfx12" rocminfo.log') ) { - arch_type = 5 - } - else if ( runShell('grep -n "gfx908" rocminfo.log') ) { - arch_type = 6 - } + def arch = check_arch() cmake_build(conf) - if ( params.RUN_INDUCTOR_TESTS && !params.BUILD_LEGACY_OS && arch_type == 1 ){ + if ( params.RUN_INDUCTOR_TESTS && !params.BUILD_LEGACY_OS && arch == 1 ){ echo "Run inductor codegen tests" sh """ python3 -m venv ${env.WORKSPACE} @@ -553,9 +565,9 @@ def Build_CK(Map conf=[:]){ """ } dir("build"){ - if (params.RUN_FULL_QA && arch_type == 2 ){ - // build deb packages for all gfx9 targets on gfx90a system and prepare to export - echo "Build ckProfiler package" + if (params.RUN_FULL_QA && arch == 2 ){ + // build deb packages + echo "Build packages" sh 'make -j package' archiveArtifacts artifacts: 'composablekernel*.deb' sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb' @@ -568,7 +580,7 @@ def Build_CK(Map conf=[:]){ // run performance tests, stash the logs, results will be processed on the master node dir("script"){ if (params.RUN_PERFORMANCE_TESTS){ - if (params.RUN_FULL_QA && arch_type == 1){ + if (params.RUN_FULL_QA && arch == 1){ // run full tests on gfx90a echo "Run full performance tests" sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" @@ -587,7 +599,7 @@ def Build_CK(Map conf=[:]){ archiveArtifacts "perf_mixed_gemm.log" stash includes: "perf_**.log", name: "perf_log" } - else if ( arch_type == 1 ){ + else if ( arch == 1 ){ // run standard tests on gfx90a echo "Run performance tests" sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" @@ -598,28 +610,28 @@ def Build_CK(Map conf=[:]){ stash includes: "perf_**.log", name: "perf_log" } // disable performance tests on gfx1030 for now. - //else if ( arch_type == 3){ + //else if ( arch == 3){ // run basic tests on gfx1030 // echo "Run gemm performance tests" // sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx10" // archiveArtifacts "perf_onnx_gemm_gfx10.log" // stash includes: "perf_onnx_gemm_gfx10.log", name: "perf_log_gfx10" //} - else if ( arch_type == 4){ + else if ( arch == 4){ // run basic tests on gfx11 echo "Run gemm performance tests" sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx11" archiveArtifacts "perf_onnx_gemm_gfx11.log" stash includes: "perf_onnx_gemm_gfx11.log", name: "perf_log_gfx11" } - else if ( arch_type == 5 ){ + else if ( arch == 5 ){ // run basic tests on gfx12 echo "Run gemm performance tests" sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx12" archiveArtifacts "perf_onnx_gemm_gfx12.log" stash includes: "perf_onnx_gemm_gfx12.log", name: "perf_log_gfx12" } - else if ( arch_type == 6 ){ + else if ( arch == 6 ){ // run basic tests on gfx908 echo "Run performance tests" sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx908" @@ -628,7 +640,7 @@ def Build_CK(Map conf=[:]){ } } } - if (params.hipTensor_test && arch_type == 1 ){ + if (params.hipTensor_test && arch == 1 ){ // build and test hipTensor on gfx90a node sh """#!/bin/bash rm -rf "${params.hipTensor_branch}".zip @@ -730,24 +742,10 @@ def process_results(Map conf=[:]){ echo "could not locate the GEMM performance logs: ${err.getMessage()}." } } - if (params.RUN_FULL_QA){ - // unstash perf files to master + if (params.RUN_FULL_QA || params.BUILD_INSTANCES_ONLY){ + // unstash deb packages unstash "packages" sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" - try{ - unstash "perf_log" - } - catch(Exception err){ - echo "could not locate perf_log: ${err.getMessage()}." - } - try{ - unstash "perf_log_gfx11" - unstash "perf_log_gfx12" - } - catch(Exception err){ - echo "could not locate the GEMM gfx11/gfx12 performance logs: ${err.getMessage()}." - } - sh "./process_qa_data.sh" } else{ // unstash perf files to master @@ -775,12 +773,12 @@ def process_results(Map conf=[:]){ } } -//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;ROCMVERSION=6.4;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true - 0 21 * * * % ROCMVERSION=6.4;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true +//launch develop branch daily jobs +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true + 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true - 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false + 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 13 * * * % BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : "" pipeline { @@ -1263,8 +1261,7 @@ pipeline { execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ - -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1151;gfx1201" \ - -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """ + -D CMAKE_CXX_FLAGS=" -O3 -ftime-trace" .. && ninja -j64 """ } steps{ buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 25ea3b2ae4..97946207a1 100755 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -103,7 +103,7 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() - message("remaining instances: ${ARGN}") + #message("remaining instances: ${ARGN}") #only continue if there are some source files left on the list if(ARGN) set(INST_OBJ) From 5b3430b868766068dabcc92394f0da65d9206099 Mon Sep 17 00:00:00 2001 From: arai713 <67439843+arai713@users.noreply.github.com> Date: Fri, 16 May 2025 11:11:54 -0700 Subject: [PATCH 007/147] Narrowing error fix for codegen compilation (#2194) * removed comment with special characters * fix for arg/template change after merge from develop --------- Co-authored-by: Thomas Ning --- ...e_gemm_pipeline_xdlops_b_preshuffle_v3.hpp | 1 - .../device_gemm_multiple_d_xdl_cshuffle.hpp | 54 ++++++++++--------- 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp index 6f3a7e6357..6f0404a1ca 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp @@ -381,7 +381,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KGroup, 1>{}([&](auto kg0) { - // K = k0 × KGroup × k1 = k0 × kg0 × A_K1 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, make_tuple(m0, I0, I0, Number{}, I0, I0), a_block_buf.At(I0), diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 6c4195e75d..f193b093d1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -860,35 +860,37 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared_block, - desc.a_element_op, - desc.b_element_op, - desc.cde_element_op, - desc.a_grid_desc_ak0_m_ak1, - desc.b_grid_desc_bk0_n_bk1, - desc.ds_grid_desc_mblock_mperblock_nblock_nperblock, - desc.e_grid_desc_mblock_mperblock_nblock_nperblock, - desc.block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + desc.cde_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.ds_grid_desc_mblock_mperblock_nblock_nperblock, + desc.e_grid_desc_mblock_mperblock_nblock_nperblock, + desc.block_2_etile_map); } else { - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared_block, - desc.a_element_op, - desc.b_element_op, - desc.cde_element_op, - desc.a_grid_desc_ak0_m_ak1, - desc.b_grid_desc_bk0_n_bk1, - desc.ds_grid_desc_mblock_mperblock_nblock_nperblock, - desc.e_grid_desc_mblock_mperblock_nblock_nperblock, - desc.block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + desc.cde_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.ds_grid_desc_mblock_mperblock_nblock_nperblock, + desc.e_grid_desc_mblock_mperblock_nblock_nperblock, + desc.block_2_etile_map); } } }; From 6342f6b5e8bbb9f2b4cefa33d2a863a8bb35329b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Sat, 17 May 2025 03:42:02 +0200 Subject: [PATCH 008/147] Restore oddc instances (#2201) --- .../gpu/grouped_convolution_forward.hpp | 8 ++ .../gpu/grouped_convolution_forward_wmma.inc | 111 ++++++++++++++++++ .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 4 + ...ma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp | 40 +++++++ ...mma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp | 40 +++++++ ...ma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp | 40 +++++++ ...mma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp | 40 +++++++ ...hwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp | 9 ++ ...l_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp | 8 ++ ...c_gkyxc_nhwgk_bf16_comp_part2_instance.cpp | 9 ++ ...nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp | 9 ++ ...dl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp | 8 ++ ...gc_gkyxc_nhwgk_f16_comp_part2_instance.cpp | 9 ++ ...dl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp | 10 +- ...l_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp | 28 ++++- ...wd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp | 10 +- ...fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp | 10 +- ...fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp | 10 +- ...wd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 10 +- ...fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 10 +- ...fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp | 10 +- ...wd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp | 10 +- ...gc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp | 11 +- ...gc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp | 11 +- ...wgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp | 11 +- ...wgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp | 11 +- ...wgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp | 11 +- ...wgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp | 11 +- ...gc_gkyxc_nhwgk_int8_mem_inter_instance.cpp | 11 +- ...gc_gkyxc_nhwgk_int8_mem_intra_instance.cpp | 11 +- .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 4 + ...gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp | 41 +++++++ ..._gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp | 41 +++++++ ...ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp | 41 +++++++ ..._ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp | 41 +++++++ 35 files changed, 682 insertions(+), 17 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index cf5dbaa323..545826650c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -613,6 +613,7 @@ struct DeviceOperationInstanceFactory>>& instances); +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -236,6 +291,20 @@ void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( + std::vector>>& instances); #endif } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index eba6fd789e..22e9d726b0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -93,6 +93,8 @@ add_instance_library(device_grouped_conv2d_fwd_instance wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp ## NHWGC, GKYXC, NHWGK wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp @@ -100,4 +102,6 @@ add_instance_library(device_grouped_conv2d_fwd_instance wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp new file mode 100644 index 0000000000..a8f723dfec --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp new file mode 100644 index 0000000000..784a118897 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp new file mode 100644 index 0000000000..8c621543a9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp new file mode 100644 index 0000000000..5cb313b3ca --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp index f5df7278d0..c078f8ed04 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp @@ -52,6 +52,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp index db048679bd..a67b11f1cf 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp @@ -49,6 +49,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp index ee9507a80a..5c0391a25f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp @@ -52,6 +52,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instanc Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp index 132d3c8411..726276c461 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp @@ -52,6 +52,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp index a7deb969ba..8b7bdec2a8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp @@ -49,6 +49,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp index d2732547fa..c66114b9a3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp @@ -52,6 +52,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp index 8a0caebc9f..93e07e08fb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" @@ -48,6 +48,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp index e45df1e107..6acbb7475c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" @@ -50,6 +50,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances( NHWGK, ConvFwd1x1S1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); + if(ck::get_device_name() != "gfx950") { add_device_operation_instances( @@ -78,6 +86,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_int8_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } if(ck::get_device_name() == "gfx950") @@ -108,6 +125,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_int8_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp index 078221f89f..2afbfdc386 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( Empty_Tuple, GNHWK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index 3a481dd204..822ef51e00 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( Empty_Tuple, GNHWK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp index 5add0f8add..79a1fb99a8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( Empty_Tuple, GNHWK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index 0257c7d315..e567c0df75 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp index 2715506fe2..3e42184996 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp index 8d3e4d91b1..c035d4c3da 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp index 465fa927a5..5c425effd8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp index 87423801cb..e8a763c527 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance NHWGK, ConvFwd1x1S1P0, Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp index ebb213461a..3ae3fb5186 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance NHWGK, ConvFwd1x1S1P0, Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp index c2c8a099b2..cb7e912936 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances NHWGK, ConvFwd1x1S1P0, Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp index 11cb853f0d..d787f4b048 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances NHWGK, ConvFwd1x1S1P0, Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp index 1992d7f7c1..5644289790 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances NHWGK, ConvFwd1x1S1P0, Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp index 2b8fd3d9db..5b12dad5a3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances NHWGK, ConvFwd1x1S1P0, Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp index 5579ec62cc..f667481fa4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance NHWGK, ConvFwd1x1S1P0, Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp index 77f3df2c11..2ff2c7f51f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance NHWGK, ConvFwd1x1S1P0, Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index f55bdd45c9..f8efa5a7c1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -66,6 +66,10 @@ set(GROUPED_CONV3D_FWD wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp ) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp new file mode 100644 index 0000000000..fa378af1ee --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, +// wo, k] +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp new file mode 100644 index 0000000000..d41416fd4a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, +// wo, k] +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp new file mode 100644 index 0000000000..8a7bc26178 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp new file mode 100644 index 0000000000..7649f86971 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From b8b12bb81e1b370d39ab7b17b0c13654a6e54721 Mon Sep 17 00:00:00 2001 From: jefyang1 <146495389+jefyang1@users.noreply.github.com> Date: Mon, 19 May 2025 14:25:50 -0700 Subject: [PATCH 009/147] Fix example_grouped_gemm_multiple_d_xdl_fp16 on gfx950 (#2203) * Fix example_grouped_gemm_multiple_d_xdl_fp16 on gfx950 * Run clang format --- example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp index db162fe444..63a2aea0b3 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -141,8 +141,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co a_tensors_device.reserve(group_count); b_tensors_device.reserve(group_count); - d_tensors_device.reserve(group_count); c_tensors_device.reserve(group_count); + d_tensors_device.resize(group_count); // reserve and update vector size std::size_t flop = 0, num_btype = 0; From 57e0f5df29abefd919c334c994628a994ba2868c Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Mon, 19 May 2025 15:52:51 -0600 Subject: [PATCH 010/147] MX GEMM - Expand MX MFMA Testing to BF8, FP6, and BF6 Data Types (#2199) * Unify test interface for different layouts. * WIP: Introducing FP4/FP6/FP8 abstractions * WIP: Introducing packed storage abstraction * WIP: Introducing packed storage abstraction * WIP: Improved support for FP6 data type * Refactor packed storage for f6_t * WIP: FP6 MFMA test * Test if we correctly represent all FP6/FP4 numbers * Additional output for failed FP4 test. * More failing conversion tests * Even more failing conversion tests * Working FP6 MFMA tests * Expand MX MFMA testing to BF8/6 * Update and verify MX MFMA test for packed types * Fix fp4 and fp6 conversions on host * Working MX MFMA tests for FP8/6/4 * Cleanup * Add missing type * Cleanup * Final cleanup * Restrict FP6/4 values output to CK_LOGGING=1 * Use CHAR_BIT instead of number 8 * Fix typo * Remove FP6 and FP4 from the list of native types --------- Co-authored-by: Rostyslav Geyyer --- include/ck/library/utility/host_tensor.hpp | 57 +-- .../library/utility/host_tensor_generator.hpp | 232 ++++++++++ include/ck/utility/amd_xdlops.hpp | 390 ++++++++++++++-- include/ck/utility/data_type.hpp | 428 +++++++----------- include/ck/utility/dtype_vector.hpp | 104 ++++- include/ck/utility/mxf4_utils.hpp | 12 +- include/ck/utility/mxf6_utils.hpp | 8 +- .../cpu/reference_gemm.hpp | 16 + .../cpu/reference_mx_gemm.hpp | 20 + test/data_type/test_bf6.cpp | 111 ++++- test/data_type/test_fp4.cpp | 57 +++ test/data_type/test_fp6.cpp | 106 ++++- test/mx_mfma_op/mx_mfma_op.cpp | 365 ++++++++++++--- test/mx_mfma_op/mx_mfma_op.hpp | 282 ++++++------ 14 files changed, 1601 insertions(+), 587 deletions(-) diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 71417ce7bf..257636d956 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -360,10 +360,9 @@ struct Tensor std::size_t GetElementSpaceSize() const { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) + if constexpr(ck::is_packed_type_v>) { - return (mDesc.GetElementSpaceSize() + 1) / 2; + return (mDesc.GetElementSpaceSize() + 1) / ck::packed_size_v>; } else { @@ -516,69 +515,31 @@ struct Tensor template std::size_t GetOffsetFromMultiIndex(Is... is) const { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - { - return mDesc.GetOffsetFromMultiIndex(is...) / 2; - } - else - { - return mDesc.GetOffsetFromMultiIndex(is...); - } + return mDesc.GetOffsetFromMultiIndex(is...) / ck::packed_size_v>; } template T& operator()(Is... is) { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - { - return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; - } - else - { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; - } + return mData[mDesc.GetOffsetFromMultiIndex(is...) / + ck::packed_size_v>]; } template const T& operator()(Is... is) const { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - { - return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; - } - else - { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; - } + return mData[mDesc.GetOffsetFromMultiIndex(is...) / + ck::packed_size_v>]; } T& operator()(std::vector idx) { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - { - return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; - } - else - { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; - } + return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v>]; } const T& operator()(std::vector idx) const { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - { - return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; - } - else - { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; - } + return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v>]; } typename Data::iterator begin() { return mData.begin(); } diff --git a/include/ck/library/utility/host_tensor_generator.hpp b/include/ck/library/utility/host_tensor_generator.hpp index 785f74a3c0..f48ba49bbf 100644 --- a/include/ck/library/utility/host_tensor_generator.hpp +++ b/include/ck/library/utility/host_tensor_generator.hpp @@ -67,6 +67,18 @@ struct GeneratorTensor_1 return ck::type_convert(value); } }; + +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::bf8_t operator()(Is...) + { + return ck::type_convert(value); + } +}; #endif template <> @@ -93,6 +105,38 @@ struct GeneratorTensor_1 } }; +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::f6x32_pk_t operator()(Is...) + { + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + r.pack(ck::type_convert(value), static_cast(i)); + }); + return r; + } +}; + +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::bf6x32_pk_t operator()(Is...) + { + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + r.pack(ck::type_convert(value), static_cast(i)); + }); + return r; + } +}; + template <> struct GeneratorTensor_1 { @@ -132,6 +176,44 @@ struct GeneratorTensor_2 } }; +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::f6x32_pk_t operator()(Is...) + { + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + r.pack(ck::type_convert(tmp), static_cast(i)); + }); + + return r; + } +}; + +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::bf6x32_pk_t operator()(Is...) + { + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + r.pack(ck::type_convert(tmp), static_cast(i)); + }); + + return r; + } +}; + template <> struct GeneratorTensor_2 { @@ -342,6 +424,46 @@ struct GeneratorTensor_3 } }; +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::f6x32_pk_t operator()(Is...) + { + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + float rnd = float(std::rand()) / float(RAND_MAX); + float fp32 = min_value + rnd * (max_value - min_value); + r.pack(ck::type_convert(fp32), static_cast(i)); + }); + + return r; + } +}; + +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::bf6x32_pk_t operator()(Is...) + { + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + float rnd = float(std::rand()) / float(RAND_MAX); + float fp32 = min_value + rnd * (max_value - min_value); + r.pack(ck::type_convert(fp32), static_cast(i)); + }); + + return r; + } +}; + template struct GeneratorTensor_4 { @@ -360,6 +482,69 @@ struct GeneratorTensor_4 } }; +template <> +struct GeneratorTensor_4 +{ + std::mt19937 generator; + std::normal_distribution distribution; + + GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) + : generator(seed), distribution(mean, stddev){}; + + template + ck::f4x2_pk_t operator()(Is...) + { + float fp32_tmp0 = distribution(generator); + float fp32_tmp1 = distribution(generator); + + return ck::f4x2_pk_t{ck::type_convert(ck::float2_t{fp32_tmp0, fp32_tmp1})}; + } +}; + +template <> +struct GeneratorTensor_4 +{ + std::mt19937 generator; + std::normal_distribution distribution; + + GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) + : generator(seed), distribution(mean, stddev){}; + + template + ck::f6x32_pk_t operator()(Is...) + { + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + r.pack(ck::type_convert(distribution(generator)), + static_cast(i)); + }); + + return r; + } +}; + +template <> +struct GeneratorTensor_4 +{ + std::mt19937 generator; + std::normal_distribution distribution; + + GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) + : generator(seed), distribution(mean, stddev){}; + + template + ck::bf6x32_pk_t operator()(Is...) + { + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + r.pack(ck::type_convert(distribution(generator)), + static_cast(i)); + }); + + return r; + } +}; + struct GeneratorTensor_Checkboard { template @@ -405,6 +590,53 @@ struct GeneratorTensor_Sequential } }; +template +struct GeneratorTensor_Sequential +{ + template + ck::f4x2_pk_t operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + + float tmp = dims[Dim]; + return ck::type_convert(ck::float2_t(tmp)); + } +}; + +template +struct GeneratorTensor_Sequential +{ + template + ck::f6x32_pk_t operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + + float tmp = dims[Dim]; + + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}( + [&](auto i) { r.pack(ck::type_convert(tmp), static_cast(i)); }); + return r; + } +}; + +template +struct GeneratorTensor_Sequential +{ + template + ck::bf6x32_pk_t operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + + float tmp = dims[Dim]; + + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}( + [&](auto i) { r.pack(ck::type_convert(tmp), static_cast(i)); }); + return r; + } +}; + template struct GeneratorTensor_Diagonal { diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 66c4958e1d..ad48389625 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -498,7 +498,7 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 0, // cbsz + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} 0, // blgp 0, 0, @@ -511,6 +511,28 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) { @@ -536,6 +558,62 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> ignore = reg_a; ignore = reg_b; ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; #endif } }; @@ -583,6 +661,43 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, + const int32_t& scale_a, + const bf8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); + // XXX: Note on the scale_a and scale_b parameters: + // If compiler detects that one or both scales are constant values, it will treat that + // constant as F32 constant. I.e., if scale_a at some point was declared as + // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is + // assigned value `bit_cast(static_cast(a_scale))`. + + // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even + // when OPSEL is set otherwise. +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const bf8x32_t& reg_a, const int32_t& scale_a, @@ -620,6 +735,74 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> #endif } + template + __device__ static void Run(const f6x32_t& reg_a, + const int32_t scale_a, + const f6x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf6x32_t& reg_a, + const int32_t scale_a, + const bf6x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const int32_t scale_a, @@ -639,7 +822,7 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, reg_c.template AsType()[Number<0>{}], - 4, // cbsz + 4, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} 4, // blgp 0, // OPSEL scale_a, @@ -748,6 +931,101 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, + const int32_t& scale_a, + const f8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f6x32_t& reg_a, + const int32_t scale_a, + const f6x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf6x32_t& reg_a, + const int32_t scale_a, + const bf6x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const int32_t scale_a, @@ -778,35 +1056,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> ignore = reg_b; ignore = scale_b; ignore = reg_c; -#endif - } - - template - __device__ static void Run(const bf8x32_t& reg_a, - const int32_t& scale_a, - const f8x32_t& reg_b, - const int32_t& scale_b, - FloatC& reg_c) - { -#if defined(__gfx950__) - // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 - reg_c.template AsType()(Number<0>{}) = - __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( - reg_a, - reg_b, - reg_c.template AsType()[Number<0>{}], - 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} - 0, // blgp - 0, // OPSEL - scale_a, - 0, // OPSEL - scale_b); -#else - ignore = reg_a; - ignore = scale_a; - ignore = reg_b; - ignore = scale_b; - ignore = reg_c; #endif } }; @@ -833,7 +1082,7 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 0, // cbsz + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} 0, // blgp 0, 0, @@ -846,6 +1095,29 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) { @@ -870,6 +1142,60 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> ignore = reg_a; ignore = reg_b; ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; #endif } }; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index a6106bb146..c11b9c0272 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -32,8 +32,14 @@ using f4_t = unsigned _BitInt(4); using f6_t = _BitInt(6); // e2m3 format using bf6_t = unsigned _BitInt(6); // e3m2 format +// scalar_type +template +struct scalar_type; + struct f4x2_pk_t { + static constexpr int packed_size = 2; + using type = uint8_t; type data; __host__ __device__ f4x2_pk_t() : data{type{}} {} @@ -55,269 +61,82 @@ struct f4x2_pk_t } }; -struct f6x16_pk_t +template +struct f6_pk_t { - // store 16 elements of f6_t in an array of 3 uint32_t - using element_type = uint32_t; - using type = StaticallyIndexedArray_v2; - type data; - typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); - f6x16_pk_t() : data{type{}} {} - f6x16_pk_t(type init) : data{init} {} + using element_type = uint32_t; // element storage fundamental type - template - __host__ __device__ inline f6_t unpack(Number) + static constexpr index_t packed_size = pk_size; + static constexpr index_t num_bits_elem = 6; + static constexpr index_t num_bits_vec_elem = sizeof(element_type) * CHAR_BIT; + static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0, + "Packed elements must fit exactly into the element storage."); + static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem; + + using storage_type = StaticallyIndexedArray_v2; + storage_type data; // packed data + + using type = f6_pk_t; + + __host__ __device__ constexpr f6_pk_t() : data{} {} + __host__ __device__ constexpr f6_pk_t(storage_type init) : data{init} {} + template ::vector_size == packed_size>> + __host__ __device__ f6_pk_t(const T& v) : data{} { - static_assert(I < 16, "Index out of range for 16 f6_t elements."); + static_for<0, packed_size, 1>{}( + [&](auto i) { pack(v[static_cast(i)], static_cast(i)); }); + } - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 3; - constexpr int bit_pos = I * num_bits_elem; - constexpr int arr_idx = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - uint32_t bits = data.At(Number{}) >> bit_offset; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + template + __host__ __device__ void pack(const T x, const index_t i) + { + static_assert(is_integral::value || is_same_v, + "T must be an integral type."); - if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + uint32_t bits = static_cast(x) & 0x3F; + const int bit_pos = i * num_bits_elem; + const int arr_index = bit_pos / num_bits_vec_elem; + const int bit_offset = bit_pos % num_bits_vec_elem; + const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = data.data_[arr_index]; + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + data.data_[arr_index] = old_value; + + // if it crosses into the next block, shift the remainder + if(overhang > 0 && (arr_index + 1) < vector_size) { - bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + uint32_t next_value = data.data_[arr_index + 1]; + next_value |= (bits >> (num_bits_elem - overhang)); + data.data_[arr_index + 1] = next_value; + } + } + + __host__ __device__ static inline BitType unpack(const type& pk, const index_t i) + { + const int bit_pos = i * num_bits_elem; + const int arr_idx = bit_pos / num_bits_vec_elem; + const int bit_offset = bit_pos % num_bits_vec_elem; + const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + uint32_t bits = pk.data.data_[arr_idx] >> bit_offset; + if(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (pk.data.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang); } - return static_cast(bits & 0x3F); + return static_cast(bits & 0x3F); } - __host__ __device__ inline type pack(const test_vec_t& x) - { - type packed{}; - - // for each of the 16 f6_t values, place its 6 bits in the correct position - ck::static_for<0, 16, 1>{}([&](auto i) { - uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 3; - constexpr int bit_pos = i * num_bits_elem; - constexpr int arr_index = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t old_value = packed.At(Number{}); - - // insert bits into the current 32-bit block - old_value |= (bits << bit_offset); - packed.At(Number{}) = old_value; - - // if it crosses into the next block, shift the remainder - if constexpr(overhang > 0 && (arr_index + 1) < vector_size) - { - uint32_t next_value = packed.At(Number{}); - next_value |= (bits >> (num_bits_elem - overhang)); - packed.At(Number{}) = next_value; - } - }); - - return packed; - } + __host__ __device__ inline BitType unpack(const index_t i) const { return unpack(*this, i); } }; -struct f6x32_pk_t -{ - // store 32 elements of f6_t in an array of 6 uint32_t - using element_type = uint32_t; - using type = StaticallyIndexedArray_v2; - type data; - typedef int8_t test_vec_t __attribute__((ext_vector_type(32))); - f6x32_pk_t() : data{type{}} {} - f6x32_pk_t(type init) : data{init} {} - - template - __host__ __device__ inline f6_t unpack(Number) - { - static_assert(I < 32, "Index out of range for 32 f6_t elements."); - - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 6; - constexpr int bit_pos = I * num_bits_elem; - constexpr int arr_idx = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - uint32_t bits = data.At(Number{}) >> bit_offset; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - - if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) - { - bits |= (data.At(Number{}) & ((1u << overhang) - 1)) - << (num_bits_elem - overhang); - } - - return static_cast(bits & 0x3F); - } - - __host__ __device__ inline type pack(const test_vec_t& x) - { - type packed{}; - - // for each of the 32 f6_t values, place its 6 bits in the correct position - ck::static_for<0, 32, 1>{}([&](auto i) { - uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 6; - constexpr int bit_pos = i * num_bits_elem; - constexpr int arr_index = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t old_value = packed.At(Number{}); - - // insert bits into the current 32-bit block - old_value |= (bits << bit_offset); - packed.At(Number{}) = old_value; - - // if it crosses into the next block, shift the remainder - if constexpr(overhang > 0 && (arr_index + 1) < vector_size) - { - uint32_t next_value = packed.At(Number{}); - next_value |= (bits >> (num_bits_elem - overhang)); - packed.At(Number{}) = next_value; - } - }); - - return packed; - } -}; - -struct bf6x16_pk_t -{ - // store 16 elements of bf6_t in an array of 3 uint32_t - using element_type = uint32_t; - using type = StaticallyIndexedArray_v2; - type data; - typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); - bf6x16_pk_t() : data{type{}} {} - bf6x16_pk_t(type init) : data{init} {} - - template - __host__ __device__ inline bf6_t unpack(Number) - { - static_assert(I < 16, "Index out of range for 16 f6_t elements."); - - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 3; - constexpr int bit_pos = I * num_bits_elem; - constexpr int arr_idx = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - uint32_t bits = data.At(Number{}) >> bit_offset; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - - if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) - { - bits |= (data.At(Number{}) & ((1u << overhang) - 1)) - << (num_bits_elem - overhang); - } - - return static_cast(bits & 0x3F); - } - - __host__ __device__ inline type pack(const test_vec_t& x) - { - type packed{}; - - // for each of the 16 bf6_t values, place its 6 bits in the correct position - ck::static_for<0, 16, 1>{}([&](auto i) { - uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 3; - constexpr int bit_pos = i * num_bits_elem; - constexpr int arr_index = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t old_value = packed.At(Number{}); - - // insert bits into the current 32-bit block - old_value |= (bits << bit_offset); - packed.At(Number{}) = old_value; - - // if it crosses into the next block, shift the remainder - if constexpr(overhang > 0 && (arr_index + 1) < vector_size) - { - uint32_t next_value = packed.At(Number{}); - next_value |= (bits >> (num_bits_elem - overhang)); - packed.At(Number{}) = next_value; - } - }); - - return packed; - } -}; - -struct bf6x32_pk_t -{ - // store 32 elements of bf6_t in an array of 6 uint32_t - using element_type = uint32_t; - using type = StaticallyIndexedArray_v2; - type data; - typedef int8_t test_vec_t __attribute__((ext_vector_type(32))); - bf6x32_pk_t() : data{type{}} {} - bf6x32_pk_t(type init) : data{init} {} - - template - __host__ __device__ inline bf6_t unpack(Number) - { - static_assert(I < 32, "Index out of range for 32 f6_t elements."); - - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 6; - constexpr int bit_pos = I * num_bits_elem; - constexpr int arr_idx = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - uint32_t bits = data.At(Number{}) >> bit_offset; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - - if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) - { - bits |= (data.At(Number{}) & ((1u << overhang) - 1)) - << (num_bits_elem - overhang); - } - - return static_cast(bits & 0x3F); - } - - __host__ __device__ inline type pack(const test_vec_t& x) - { - type packed{}; - - // for each of the 32 bf6_t values, place its 6 bits in the correct position - ck::static_for<0, 32, 1>{}([&](auto i) { - uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 6; - constexpr int bit_pos = i * num_bits_elem; - constexpr int arr_index = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t old_value = packed.At(Number{}); - - // insert bits into the current 32-bit block - old_value |= (bits << bit_offset); - packed.At(Number{}) = old_value; - - // if it crosses into the next block, shift the remainder - if constexpr(overhang > 0 && (arr_index + 1) < vector_size) - { - uint32_t next_value = packed.At(Number{}); - next_value |= (bits >> (num_bits_elem - overhang)); - packed.At(Number{}) = next_value; - } - }); - - return packed; - } -}; +using f6x16_pk_t = f6_pk_t; +using f6x32_pk_t = f6_pk_t; +using bf6x16_pk_t = f6_pk_t; +using bf6x32_pk_t = f6_pk_t; // custom data type - pack int4 data struct pk_i4_t @@ -335,15 +154,14 @@ inline constexpr auto next_pow2(uint32_t x) } // native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, -// native types: bool, f4_t, f6_t, bf6_t +// native types: bool template inline constexpr bool is_native_type() { return is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value; + is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value; } // scalar_type @@ -484,6 +302,106 @@ struct scalar_type static constexpr index_t vector_size = 1; }; +// Default behavior for types that do not need special handling +template +struct packed_type +{ + using type = T; + static constexpr index_t packed_size = 1; // number of packed elements +}; + +template <> +struct packed_type +{ + using type = pk_i4_t; + static constexpr index_t packed_size = 2; // number of packed elements +}; + +template <> +struct packed_type +{ + using type = f4x2_pk_t; + static constexpr index_t packed_size = 2; // number of packed elements +}; + +template <> +struct packed_type +{ + using type = f6x32_pk_t; + static constexpr index_t packed_size = f6x32_pk_t::packed_size; // number of packed elements +}; + +template <> +struct packed_type +{ + using type = bf6x32_pk_t; + static constexpr index_t packed_size = bf6x32_pk_t::packed_size; // number of packed elements +}; + +template +using packed_type_t = typename packed_type::type; + +// Check if the type has packed type specialization +template +inline constexpr bool has_packed_type_v = !is_same_v, T>; + +template +struct element_type +{ + private: + static constexpr auto get_element_type() + { + using U = remove_cvref_t; + if constexpr(is_same_v) + return int4_t{}; + else if constexpr(is_same_v) + return f4_t{}; + else if constexpr(is_same_v) + return f6_t{}; + else if constexpr(is_same_v) + return bf6_t{}; + else if constexpr(is_same_v) + return f6_t{}; + else if constexpr(is_same_v) + return bf6_t{}; + else + return T{}; + } + + public: + using type = decltype(get_element_type()); +}; +template +using element_type_t = typename element_type::type; + +template +inline constexpr bool is_packed_type_v = + has_packed_type_v>&& is_same_v>>; + +template +struct packed_size +{ + private: + static constexpr auto get_packed_size() + { + using U = remove_cvref_t; + if constexpr(is_packed_type_v) + return Number>::packed_size>{}; + else + return Number::packed_size>{}; + } + + public: + using type = decltype(get_packed_size()); + static constexpr auto value = get_packed_size(); +}; + +template +using packed_size_t = typename packed_size::type; + +template +inline constexpr index_t packed_size_v = packed_size::value; + #if defined(_WIN32) using int64_t = long long; #else diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 9c40d923d3..65eed0624c 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -365,6 +365,88 @@ struct vector_type()>> } }; +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d3_t __attribute__((ext_vector_type(3))); + typedef T d6_t __attribute__((ext_vector_type(6))); + + using type = d6_t; + + union + { + d6_t d6_; + StaticallyIndexedArray d1x6_; + StaticallyIndexedArray d2x3_; + StaticallyIndexedArray d3x2_; + StaticallyIndexedArray d6x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x6_; + } + else if constexpr(is_same::value) + { + return data_.d2x3_; + } + else if constexpr(is_same::value) + { + return data_.d3x2_; + } + else if constexpr(is_same::value) + { + return data_.d6x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x6_; + } + else if constexpr(is_same::value) + { + return data_.d2x3_; + } + else if constexpr(is_same::value) + { + return data_.d3x2_; + } + else if constexpr(is_same::value) + { + return data_.d6x1_; + } + else + { + return err; + } + } +}; + template struct vector_type()>> { @@ -1221,25 +1303,25 @@ struct nnvb_data_t_selector template <> struct nnvb_data_t_selector { - using type = f6x16_pk_t::type; + using type = f6x16_pk_t::storage_type; }; template <> struct nnvb_data_t_selector { - using type = f6x32_pk_t::type; + using type = f6x32_pk_t::storage_type; }; template <> struct nnvb_data_t_selector { - using type = bf6x16_pk_t::type; + using type = bf6x16_pk_t::storage_type; }; template <> struct nnvb_data_t_selector { - using type = bf6x32_pk_t::type; + using type = bf6x32_pk_t::storage_type; }; template <> @@ -1406,12 +1488,23 @@ struct non_native_vector_base -struct scalar_type> +struct scalar_type>> { using type = typename non_native_vector_base::data_t; static constexpr index_t vector_size = N; }; +template +struct scalar_type< + non_native_vector_base>> +{ + using type = typename non_native_vector_base::element_t; + static constexpr index_t vector_size = N * non_native_vector_base::size_factor; +}; + // non-native vector_type implementation template struct vector_type()>> @@ -2025,6 +2118,7 @@ using bhalf32_t = typename vector_type::type; // i32 using int32x2_t = typename vector_type::type; using int32x4_t = typename vector_type::type; +using int32x6_t = typename vector_type::type; using int32x8_t = typename vector_type::type; using int32x16_t = typename vector_type::type; using int32x32_t = typename vector_type::type; diff --git a/include/ck/utility/mxf4_utils.hpp b/include/ck/utility/mxf4_utils.hpp index b0b5297f77..53edb6e182 100644 --- a/include/ck/utility/mxf4_utils.hpp +++ b/include/ck/utility/mxf4_utils.hpp @@ -66,7 +66,7 @@ __host__ __device__ inline f4_t sat_convert_to_type(float value) : NumericUtils::data_max_positive_normal_mask; } - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; @@ -74,8 +74,8 @@ __host__ __device__ inline f4_t sat_convert_to_type(float value) if(std::abs(to_float(NumericLimits::Binary_1(), res)) < NumericLimits::DataMinSubnorm()) - return value < 0 ? NumericUtils::negative_zero_mask - : NumericUtils::positive_zero_mask; + return sign ? NumericUtils::negative_zero_mask + : NumericUtils::positive_zero_mask; return res; } @@ -91,7 +91,7 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr(float value, uint32 return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; @@ -99,8 +99,8 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr(float value, uint32 if(std::abs(to_float(NumericLimits::Binary_1(), res)) < NumericLimits::DataMinSubnorm()) - return value < 0 ? NumericUtils::negative_zero_mask - : NumericUtils::positive_zero_mask; + return sign ? NumericUtils::negative_zero_mask + : NumericUtils::positive_zero_mask; return res; } diff --git a/include/ck/utility/mxf6_utils.hpp b/include/ck/utility/mxf6_utils.hpp index cf68188b3e..a840c520a9 100644 --- a/include/ck/utility/mxf6_utils.hpp +++ b/include/ck/utility/mxf6_utils.hpp @@ -201,7 +201,7 @@ __host__ __device__ inline f6_t sat_convert_to_type(float value) : NumericUtils::data_max_positive_normal_mask; } - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; @@ -239,7 +239,7 @@ __host__ __device__ inline bf6_t sat_convert_to_type(float value) : NumericUtils::data_max_positive_normal_mask; } - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; @@ -274,7 +274,7 @@ __host__ __device__ inline f6_t sat_convert_to_type_sr(float value, uint32 return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; @@ -308,7 +308,7 @@ __host__ __device__ inline bf6_t sat_convert_to_type_sr(float value, uint if(std::isnan(value)) return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index c8d284a1d7..ed07e53e6d 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -89,6 +89,14 @@ struct ReferenceGemm : public device::BaseOperator v_a = type_convert( f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))); } + else if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + v_a = type_convert( + arg.a_m_k_(m, k).unpack(k % ADataType::packed_size)); + } else { arg.a_element_op_(v_a, arg.a_m_k_(m, k)); @@ -115,6 +123,14 @@ struct ReferenceGemm : public device::BaseOperator v_b = type_convert( f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))); } + else if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + v_b = type_convert( + arg.b_k_n_(k, n).unpack(k % BDataType::packed_size)); + } else { arg.b_element_op_(v_b, arg.b_k_n_(k, n)); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp index e8fdcf1acd..3fc39911dd 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp @@ -105,6 +105,16 @@ struct ReferenceMXGemm : public device::BaseOperator type_convert( arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); } + else if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + a_m_k_scaled(m, k) = + type_convert( + arg.a_m_k_(m, k).unpack(k % ADataType::packed_size)) * + type_convert(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + } else { a_m_k_scaled(m, k) = @@ -134,6 +144,16 @@ struct ReferenceMXGemm : public device::BaseOperator type_convert( arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); } + else if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + b_k_n_scaled(k, n) = + type_convert( + arg.b_k_n_(k, n).unpack(k % BDataType::packed_size)) * + type_convert(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + } else { b_k_n_scaled(k, n) = diff --git a/test/data_type/test_bf6.cpp b/test/data_type/test_bf6.cpp index a260f81d16..9dbb77454c 100644 --- a/test/data_type/test_bf6.cpp +++ b/test/data_type/test_bf6.cpp @@ -4,6 +4,7 @@ #include "gtest/gtest.h" #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" +#include "ck/utility/env.hpp" #include "ck/utility/scaled_type_convert.hpp" using ck::bf6_convert_rne; @@ -41,6 +42,11 @@ TEST(BF6, ConvertFP32Nearest) ASSERT_NEAR(max_bf6, type_convert(bf6_convert_rne(std::numeric_limits::infinity())), 0.0f); + + // convert float +/-30 to bf6 and back, check if clipped to +/-max_bf6 + ASSERT_NEAR(-max_bf6, type_convert(bf6_convert_rne(-30.0f)), 0.0f); + ASSERT_NEAR(max_bf6, type_convert(bf6_convert_rne(30.0f)), 0.0f); + // convert float value less than bf6 subnorm to bf6 and back, check if equal to 0.0 float less_than_subnorm = 0.03125f; ASSERT_NEAR(0.0f, type_convert(bf6_convert_rne(less_than_subnorm)), 0.0f); @@ -266,21 +272,18 @@ TEST(BF6, TestAsType16x1) vector_type right_vec; // check default CTOR ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - right_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - 0); + ASSERT_EQ(right_vec.template AsType()(Number<0>{}).unpack(i), 0); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = bf6x16_pk_t{}.pack(test_vec); + right_vec.template AsType()(Number{}) = bf6x16_pk_t{test_vec}; }); // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - left_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - static_cast(test_vec[static_cast(i)])); + ASSERT_EQ(left_vec.template AsType()(Number<0>{}).unpack(i), + static_cast(test_vec[static_cast(i)])); }); } @@ -329,23 +332,23 @@ TEST(BF6, TestAsType16x2) // check default CTOR ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) { ck::static_for<0, packed_size, 1>{}([&](auto idx_element) { - ASSERT_EQ(right_vec.template AsType()(Number{}) - .template unpack<>(Number{}), - 0); + ASSERT_EQ( + right_vec.template AsType()(Number{}).unpack(idx_element), + 0); }); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = bf6x16_pk_t{}.pack(test_vec[i]); + right_vec.template AsType()(Number{}) = bf6x16_pk_t{test_vec[i]}; }); // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) { ck::static_for<0, packed_size, 1>{}([&](auto idx_element) { - ASSERT_EQ(left_vec.template AsType()(Number{}) - .template unpack<>(Number{}), - static_cast(test_vec[idx_vector][static_cast(idx_element)])); + ASSERT_EQ( + left_vec.template AsType()(Number{}).unpack(idx_element), + static_cast(test_vec[idx_vector][static_cast(idx_element)])); }); }); } @@ -369,20 +372,86 @@ TEST(BF6, TestAsType32x1) vector_type right_vec; // check default CTOR ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - right_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - 0); + ASSERT_EQ(right_vec.template AsType()(Number<0>{}).unpack(i), 0); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = bf6x32_pk_t{}.pack(test_vec); + right_vec.template AsType()(Number{}) = bf6x32_pk_t{test_vec}; }); // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - left_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - static_cast(test_vec[static_cast(i)])); + ASSERT_EQ(left_vec.template AsType()(Number<0>{}).unpack(i), + static_cast(test_vec[static_cast(i)])); + }); +} + +TEST(BF6, TestAllValues) +{ + + constexpr std::array e3m2ValuesOCP = { + // clang-format off + 0.0000000000, 0.0625000000, 0.1250000000, 0.1875000000, + 0.2500000000, 0.3125000000, 0.3750000000, 0.4375000000, + 0.5000000000, 0.6250000000, 0.7500000000, 0.8750000000, + 1.0000000000, 1.2500000000, 1.5000000000, 1.7500000000, + 2.0000000000, 2.5000000000, 3.0000000000, 3.5000000000, + 4.0000000000, 5.0000000000, 6.0000000000, 7.0000000000, + 8.0000000000, 10.0000000000, 12.0000000000, 14.0000000000, + 16.0000000000, 20.0000000000, 24.0000000000, 28.0000000000, + -0.0000000000, -0.0625000000, -0.1250000000, -0.1875000000, + -0.2500000000, -0.3125000000, -0.3750000000, -0.4375000000, + -0.5000000000, -0.6250000000, -0.7500000000, -0.8750000000, + -1.0000000000, -1.2500000000, -1.5000000000, -1.7500000000, + -2.0000000000, -2.5000000000, -3.0000000000, -3.5000000000, + -4.0000000000, -5.0000000000, -6.0000000000, -7.0000000000, + -8.0000000000, -10.0000000000, -12.0000000000, -14.0000000000, + -16.0000000000, -20.0000000000, -24.0000000000, -28.0000000000 + // clang-format on + }; + + constexpr uint8_t e3m2BitsOCP[] = { + // clang-format off + 0b000000, 0b000001, 0b000010, 0b000011, + 0b000100, 0b000101, 0b000110, 0b000111, + 0b001000, 0b001001, 0b001010, 0b001011, + 0b001100, 0b001101, 0b001110, 0b001111, + 0b010000, 0b010001, 0b010010, 0b010011, + 0b010100, 0b010101, 0b010110, 0b010111, + 0b011000, 0b011001, 0b011010, 0b011011, + 0b011100, 0b011101, 0b011110, 0b011111, + 0b100000, 0b100001, 0b100010, 0b100011, + 0b100100, 0b100101, 0b100110, 0b100111, + 0b101000, 0b101001, 0b101010, 0b101011, + 0b101100, 0b101101, 0b101110, 0b101111, + 0b110000, 0b110001, 0b110010, 0b110011, + 0b110100, 0b110101, 0b110110, 0b110111, + 0b111000, 0b111001, 0b111010, 0b111011, + 0b111100, 0b111101, 0b111110, 0b111111 + // clang-format on + }; + + const bool ck_logging = ck::EnvIsEnabled(CK_ENV(CK_LOGGING)); + + if(ck_logging) + printf("BF6 Table\n"); + ck::static_for<0, 64, 1>{}([&](auto i) { + float fp = type_convert(bf6_t(e3m2BitsOCP[i])); + ASSERT_EQ(fp, e3m2ValuesOCP[i]); + + bf6_t bf6 = type_convert(e3m2ValuesOCP[i]); + ASSERT_EQ(bf6 & 0x3F, e3m2BitsOCP[i] & 0x3F); + + if(ck_logging) + { + // Print the binary representation + printf("Bits: 0b"); + for(int j = 5; j >= 0; --j) + { + printf("%c", (e3m2BitsOCP[i] & (1 << j)) ? '1' : '0'); + } + printf(", 0x%02X, Value: %f\n", e3m2BitsOCP[i], e3m2ValuesOCP[i]); + } }); } diff --git a/test/data_type/test_fp4.cpp b/test/data_type/test_fp4.cpp index f4b2bf3358..3fc74a2ef3 100644 --- a/test/data_type/test_fp4.cpp +++ b/test/data_type/test_fp4.cpp @@ -5,6 +5,7 @@ #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" #include "ck/utility/scaled_type_convert.hpp" +#include "ck/utility/env.hpp" using ck::e8m0_bexp_t; using ck::f4_convert_rne; @@ -38,6 +39,11 @@ TEST(FP4, ConvertFP32Nearest) // convert maximal float to fp4 and back, check if clipped to 6.0 ASSERT_NEAR( max_fp4, type_convert(f4_convert_rne(std::numeric_limits::max())), abs_tol); + + // convert +/-7.0 to fp4 and back, check if clipped to +/-6.0 + ASSERT_NEAR(-max_fp4, type_convert(f4_convert_rne(-7.0f)), 0.0); + ASSERT_NEAR(max_fp4, type_convert(f4_convert_rne(7.0f)), 0.0); + // positive norm float value to fp4 and back, check if holds float pos_float = 1.0f; ASSERT_NEAR(pos_float, type_convert(f4_convert_rne(pos_float)), abs_tol); @@ -468,3 +474,54 @@ TEST(FP4, TestAsType32) test_vec.at(i + 1)); }); } + +TEST(FP4, TestAllValues) +{ + constexpr std::array e2m1ValuesOCP = { + // clang-format off + 0.0000000000, 0.5000000000, + 1.0000000000, 1.5000000000, + 2.0000000000, 3.0000000000, + 4.0000000000, 6.0000000000, + -0.0000000000, -0.5000000000, + -1.0000000000, -1.5000000000, + -2.0000000000, -3.0000000000, + -4.0000000000, -6.0000000000 + // clang-format on + }; + + constexpr uint8_t e2m1BitsOCP[] = { + // clang-format off + 0b0000, 0b0001, + 0b0010, 0b0011, + 0b0100, 0b0101, + 0b0110, 0b0111, + 0b1000, 0b1001, + 0b1010, 0b1011, + 0b1100, 0b1101, + 0b1110, 0b1111 + // clang-format on + }; + + const bool ck_logging = ck::EnvIsEnabled(CK_ENV(CK_LOGGING)); + + if(ck_logging) + printf("FP4 Table\n"); + ck::static_for<0, 16, 1>{}([&](auto i) { + float fp = type_convert(f4_t(e2m1BitsOCP[i])); + ASSERT_EQ(fp, e2m1ValuesOCP[i]); + + f4_t fp4 = type_convert(e2m1ValuesOCP[i]); + ASSERT_EQ(fp4 & 0xF, e2m1BitsOCP[i] & 0xF); + if(ck_logging) + { + // Print the binary representation + printf("Bits: 0b"); + for(int j = 3; j >= 0; --j) + { + printf("%c", (e2m1BitsOCP[i] & (1 << j)) ? '1' : '0'); + } + printf(", 0x%02X, Value: %f\n", e2m1BitsOCP[i], e2m1ValuesOCP[i]); + } + }); +} diff --git a/test/data_type/test_fp6.cpp b/test/data_type/test_fp6.cpp index cf91e69db3..6d4aec1d9a 100644 --- a/test/data_type/test_fp6.cpp +++ b/test/data_type/test_fp6.cpp @@ -4,6 +4,7 @@ #include "gtest/gtest.h" #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" +#include "ck/utility/env.hpp" #include "ck/utility/scaled_type_convert.hpp" using ck::e8m0_bexp_t; @@ -34,6 +35,11 @@ TEST(FP6, ConvertFP32Nearest) ASSERT_NEAR(0.0f, type_convert(f6_convert_rne(0.0f)), 0.0f); // convert maximal f6_t to float and check if equal to max_fp6 ASSERT_NEAR(max_fp6, type_convert(f6_convert_rne(max_fp6)), 0.0f); + + // convert maximal +/-8.0 to fp6 and check if equal to +/-max_fp6 + ASSERT_NEAR(-max_fp6, type_convert(f6_convert_rne(-8.0f)), 0.0f); + ASSERT_NEAR(max_fp6, type_convert(f6_convert_rne(8.0f)), 0.0f); + // convert maximal float to fp6 and back, check if clipped to max_fp6 ASSERT_NEAR( max_fp6, type_convert(f6_convert_rne(std::numeric_limits::max())), 0.0f); @@ -265,20 +271,24 @@ TEST(FP6, TestAsType16x1) vector_type right_vec; // check default CTOR ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - right_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), 0); + ASSERT_EQ(right_vec.template AsType()(Number<0>{}).unpack(i), 0); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = f6x16_pk_t{}.pack(test_vec); + right_vec.template AsType()(Number{}) = f6x16_pk_t{test_vec}; }); + // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - left_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - static_cast(test_vec[static_cast(i)])); + ASSERT_EQ(left_vec.template AsType()(Number<0>{}).unpack(i), + static_cast(test_vec[static_cast(i)])) + << " i = " << i << "; left = " + << type_convert(left_vec.template AsType()(Number<0>{}).unpack(i)) + << " -- right = " + << type_convert(static_cast(test_vec[static_cast(i)])) << " (" + << static_cast(test_vec[static_cast(i)]) << ")" << std::endl; }); } @@ -327,23 +337,23 @@ TEST(FP6, TestAsType16x2) // check default CTOR ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) { ck::static_for<0, packed_size, 1>{}([&](auto idx_element) { - ASSERT_EQ(right_vec.template AsType()(Number{}) - .template unpack<>(Number{}), - 0); + ASSERT_EQ( + right_vec.template AsType()(Number{}).unpack(idx_element), + 0); }); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = f6x16_pk_t{}.pack(test_vec[i]); + right_vec.template AsType()(Number{}) = f6x16_pk_t{test_vec[i]}; }); // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) { ck::static_for<0, packed_size, 1>{}([&](auto idx_element) { - ASSERT_EQ(left_vec.template AsType()(Number{}) - .template unpack<>(Number{}), - static_cast(test_vec[idx_vector][static_cast(idx_element)])); + ASSERT_EQ( + left_vec.template AsType()(Number{}).unpack(idx_element), + static_cast(test_vec[idx_vector][static_cast(idx_element)])); }); }); } @@ -367,19 +377,77 @@ TEST(FP6, TestAsType32x1) vector_type right_vec; // check default CTOR ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - right_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), 0); + ASSERT_EQ(right_vec.template AsType()(Number<0>{}).unpack(i), 0); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = f6x32_pk_t{}.pack(test_vec); + right_vec.template AsType()(Number{}) = f6x32_pk_t{test_vec}; }); // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - left_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - static_cast(test_vec[static_cast(i)])); + ASSERT_EQ(left_vec.template AsType()(Number<0>{}).unpack(i), + static_cast(test_vec[static_cast(i)])); + }); +} + +TEST(FP6, TestAllValues) +{ + constexpr std::array e2m3ValuesOCP = { + // clang-format off + 0.0000000000, 0.1250000000, 0.2500000000, 0.3750000000, 0.5000000000, 0.6250000000, 0.7500000000, 0.8750000000, + 1.0000000000, 1.1250000000, 1.2500000000, 1.3750000000, 1.5000000000, 1.6250000000, 1.7500000000, 1.8750000000, + 2.0000000000, 2.2500000000, 2.5000000000, 2.7500000000, 3.0000000000, 3.2500000000, 3.5000000000, 3.7500000000, + 4.0000000000, 4.5000000000, 5.0000000000, 5.5000000000, 6.0000000000, 6.5000000000, 7.0000000000, 7.5000000000, + -0.0000000000, -0.1250000000, -0.2500000000, -0.3750000000, -0.5000000000, -0.6250000000, -0.7500000000, -0.8750000000, + -1.0000000000, -1.1250000000, -1.2500000000, -1.3750000000, -1.5000000000, -1.6250000000, -1.7500000000, -1.8750000000, + -2.0000000000, -2.2500000000, -2.5000000000, -2.7500000000, -3.0000000000, -3.2500000000, -3.5000000000, -3.7500000000, + -4.0000000000, -4.5000000000, -5.0000000000, -5.5000000000, -6.0000000000, -6.5000000000, -7.0000000000, -7.5000000000 + // clang-format on + }; + + constexpr uint8_t e2m3BitsOCP[] = { + // clang-format off + 0b000000, 0b000001, 0b000010, 0b000011, + 0b000100, 0b000101, 0b000110, 0b000111, + 0b001000, 0b001001, 0b001010, 0b001011, + 0b001100, 0b001101, 0b001110, 0b001111, + 0b010000, 0b010001, 0b010010, 0b010011, + 0b010100, 0b010101, 0b010110, 0b010111, + 0b011000, 0b011001, 0b011010, 0b011011, + 0b011100, 0b011101, 0b011110, 0b011111, + 0b100000, 0b100001, 0b100010, 0b100011, + 0b100100, 0b100101, 0b100110, 0b100111, + 0b101000, 0b101001, 0b101010, 0b101011, + 0b101100, 0b101101, 0b101110, 0b101111, + 0b110000, 0b110001, 0b110010, 0b110011, + 0b110100, 0b110101, 0b110110, 0b110111, + 0b111000, 0b111001, 0b111010, 0b111011, + 0b111100, 0b111101, 0b111110, 0b111111 + // clang-format on + }; + + const bool ck_logging = ck::EnvIsEnabled(CK_ENV(CK_LOGGING)); + + if(ck_logging) + printf("FP6 Table\n"); + ck::static_for<0, 64, 1>{}([&](auto i) { + float fp = type_convert(f6_t(e2m3BitsOCP[i])); + ASSERT_EQ(fp, e2m3ValuesOCP[i]); + + f6_t fp6 = type_convert(e2m3ValuesOCP[i]); + ASSERT_EQ(fp6 & 0x3F, e2m3BitsOCP[i] & 0x3F); + + if(ck_logging) + { + // Print the binary representation + printf("Bits: 0b"); + for(int j = 5; j >= 0; --j) + { + printf("%c", (e2m3BitsOCP[i] & (1 << j)) ? '1' : '0'); + } + printf(", 0x%02X, Value: %f\n", e2m3BitsOCP[i], e2m3ValuesOCP[i]); + } }); } diff --git a/test/mx_mfma_op/mx_mfma_op.cpp b/test/mx_mfma_op/mx_mfma_op.cpp index fddb8288a6..5e2aedd35e 100644 --- a/test/mx_mfma_op/mx_mfma_op.cpp +++ b/test/mx_mfma_op/mx_mfma_op.cpp @@ -5,9 +5,12 @@ #include "mx_mfma_op.hpp" +using ck::bf6_t; +using ck::bf8_t; using ck::e8m0_bexp_t; using ck::f4_t; using ck::f4x2_pk_t; +using ck::f6_t; using ck::f8_t; using ck::half_t; using ck::type_convert; @@ -17,13 +20,15 @@ using ck::type_convert; * * @param init - selects initialization algorithm for A and B tensors */ -template -bool run_mfma_km_kn_nm_test(ck::index_t init) +template +bool run_mfma_test(ck::index_t init) { - using ALayout = ck::tensor_layout::gemm::ColumnMajor; - using BLayout = ck::tensor_layout::gemm::ColumnMajor; - using CLayout = ck::tensor_layout::gemm::ColumnMajor; - using AccType = float; // only MFMA_F32 instructions supported using CPUAccType = AccType; @@ -53,74 +58,153 @@ bool run_mfma_km_kn_nm_test(ck::index_t init) return pass; } +const ck::index_t common_init = -4; // set to "< 0" for test-specific initializations + TEST(MFMA, FP8MFMA16x16x128) { - auto AB_init = 5; - auto pass = run_mfma_km_kn_nm_test(AB_init); + using ALayout = ck::tensor_layout::gemm::ColumnMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::ColumnMajor; + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); } -TEST(MFMA, FP8MFMA32x32x64) +TEST(MFMA, BF8MFMA16x16x128) { - auto AB_init = 5; - auto pass = run_mfma_km_kn_nm_test(AB_init); + using ALayout = ck::tensor_layout::gemm::ColumnMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::ColumnMajor; + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); } -/** - * @brief Run the test for the given MFMA instruction - * - * @param init - selects initialization algorithm for A and B tensors - */ -template -bool run_mfma_mk_kn_mn_test(ck::index_t init) +TEST(MFMA, FP4MFMA16x16x128) { using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; using CLayout = ck::tensor_layout::gemm::RowMajor; - using AccType = float; // only MFMA_F32 instructions supported - using CPUAccType = AccType; - - ck::mfma_type(mfma)> mfma_instr; - constexpr auto BLOCK_M = mfma_instr.m_per_blk; - constexpr auto BLOCK_N = mfma_instr.n_per_blk; - constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; - - const auto mfma_kernel = ck:: - matmul; - - bool pass = true; - - pass = ck::mfma_test::TestMFMA{}(mfma_kernel, init); - - return pass; + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = + run_mfma_test( + AB_init); + EXPECT_TRUE(pass); } -TEST(MFMA, FP4MFMA16x16x128) +TEST(MFMA, FP6MFMA16x16x128) { - auto AB_init = 4; - auto pass = run_mfma_mk_kn_mn_test( - AB_init); + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = + run_mfma_test( + AB_init); + EXPECT_TRUE(pass); +} + +TEST(MFMA, BF6MFMA16x16x128) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MFMA, FP8MFMA32x32x64) +{ + using ALayout = ck::tensor_layout::gemm::ColumnMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::ColumnMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = + run_mfma_test( + AB_init); + EXPECT_TRUE(pass); +} + +TEST(MFMA, BF8MFMA32x32x64) +{ + using ALayout = ck::tensor_layout::gemm::ColumnMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::ColumnMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); } TEST(MFMA, FP4MFMA32x32x64) { - auto AB_init = 4; - auto pass = run_mfma_mk_kn_mn_test( - AB_init); + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = + run_mfma_test( + AB_init); + EXPECT_TRUE(pass); +} + +TEST(MFMA, FP6MFMA32x32x64) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = + run_mfma_test( + AB_init); + EXPECT_TRUE(pass); +} + +TEST(MFMA, BF6MFMA32x32x64) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); } @@ -129,15 +213,18 @@ TEST(MFMA, FP4MFMA32x32x64) * * @param init - selects initialization algorithm for A and B tensors */ -template -bool run_mxmfma_mk_kn_mn_test(ck::index_t init) +template +bool run_mxmfma_test(ck::index_t init) { static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 || mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64, "Only SCALE_F32_16x16x128 and SCALE_F32_32x32x64 are supported"); - using ALayout = ck::tensor_layout::gemm::RowMajor; - using BLayout = ck::tensor_layout::gemm::ColumnMajor; - using CLayout = ck::tensor_layout::gemm::RowMajor; using AccType = float; // only MFMA_F32 instructions supported using ScaleType = ck::e8m0_bexp_t; // biased exponent type @@ -181,34 +268,170 @@ bool run_mxmfma_mk_kn_mn_test(ck::index_t init) TEST(MXMFMA, MXFP8MFMA16x16x128) { - auto AB_init = 5; - auto pass = - run_mxmfma_mk_kn_mn_test(AB_init); + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); EXPECT_TRUE(pass); } TEST(MXMFMA, MXFP8MFMA32x32x64) { - auto AB_init = 5; - auto pass = - run_mxmfma_mk_kn_mn_test(AB_init); + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXBF8MFMA16x16x128) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXBF8MFMA32x32x64) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXFP6MFMA16x16x128) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXFP6MFMA32x32x64) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXBF6MFMA16x16x128) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXBF6MFMA32x32x64) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); EXPECT_TRUE(pass); } TEST(MXMFMA, MXFP4MFMA16x16x128) { - auto AB_init = 4; - auto pass = - run_mxmfma_mk_kn_mn_test( - AB_init); + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); EXPECT_TRUE(pass); } TEST(MXMFMA, MXFP4MFMA32x32x64) { - auto AB_init = 4; - auto pass = - run_mxmfma_mk_kn_mn_test( - AB_init); + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); EXPECT_TRUE(pass); } diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index 9ce871cfb1..4cab411cb4 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -151,6 +151,8 @@ __device__ AFragT load_A_col_major(AType const* input_ptr) // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | // clang-format on + static_assert(!is_packed_type_v, "Packed type is not supported"); + static constexpr int32_t WAVE_SIZE = 64; // Here we want to load from rows of A in chunks of 16 elements each. @@ -270,12 +272,28 @@ __device__ AFragT load_A_row_major(AType const* input_ptr) // Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] | // Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] | // Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] | + + + // Register Mapping for 16x128 for FP6: || Register Mapping for 32x64 for FP6: + // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | | + // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | K = 64-79 | K = 96-111 | v[0] || Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | v[0] | + // Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | K = 80-95 | K = 112-127 | v[0] || Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | v[0] | + // clang-format on static constexpr int32_t WAVE_SIZE = 64; + // FP8 chunk_size = 16, num_chunks = 2, packed_size = 1 + // FP4 chunk_size = 32, num_chunks = 1, packed_size = 2 + // FP6 chunk_size = 32, num_chunks = 1, packed_size = 32 + + constexpr index_t num_chunks = is_packed_type_v ? 1 : 2; + // Here we want to load from rows of A in chunks of 16 elements each. - static constexpr uint32_t chunk_size = 16; + constexpr uint32_t chunk_size = is_packed_type_v ? 32 : 16; // each chunk is separated by offset static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M; @@ -283,43 +301,35 @@ __device__ AFragT load_A_row_major(AType const* input_ptr) // To start the loading process, let's visualize in 2D coords. // Each thread will load 32 elements. // We need to know where they start, and where the next elements are. - auto startCoord2D = - std::make_pair(threadIdx.x % BLOCK_M, // Row {0-31} | {0-15} - (threadIdx.x / BLOCK_M) * chunk_size); // Col {0, 16} | {0, 16, 32, 48} + // FP8/6/4 Row {0-31} | {0-15} + // FP8 Col {0, 16} | {0, 16, 32, 48} + // FP6/4 Col {0, 32} | {0, 32, 64, 96} + auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, (threadIdx.x / BLOCK_M) * chunk_size); - // auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows auto majorStepCoord2D = std::make_pair(0, chunk_offset); // read a chunk from a row // Flatten to 1D row_major offsets. auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; - // BLOCK_K is a stride in A matrix - auto startOffset = row_major( - startCoord2D, BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); - // auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K / - // (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); - auto kMajorOffset = - row_major(majorStepCoord2D, - BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); - - using ARawT = typename scalar_type::type; - using AScalarFragT = vector_type::type; - - constexpr index_t num_chunks = - (ck::is_same_v, ck::f4x2_pk_t> ? 1 : 2); + using ARawT = typename scalar_type::type; + using AScalarChunkT = vector_type::vector_size / num_chunks>::type; union { AFragT frag; - AScalarFragT chunks[num_chunks]; + AScalarChunkT chunks[num_chunks]; } fragA{}; - const AScalarFragT* fragPtr; + const AScalarChunkT* fragPtr; + + // BLOCK_K is a stride in A matrix + auto startOffset = row_major(startCoord2D, BLOCK_K) / packed_size_v; + auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K) / packed_size_v; for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++) { - fragPtr = reinterpret_cast(input_ptr + startOffset + - chunk_idx * kMajorOffset); + fragPtr = reinterpret_cast(input_ptr + startOffset + + chunk_idx * kMajorOffset); fragA.chunks[chunk_idx] = *fragPtr; } @@ -488,12 +498,27 @@ __device__ BFragT load_B_col_major(BType const* input_ptr) // Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] | // Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] | // Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] | + + // Register Mapping for 16x128 for FP6: || Register Mapping for 32x64 for FP6: + // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | | + // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | K = 64-79 | K = 96-111 | v[0] || Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | v[0] | + // Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | K = 80-95 | K = 112-127 | v[0] || Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | v[0] | + // clang-format on static constexpr int32_t WAVE_SIZE = 64; + // FP8 chunk_size = 16, num_chunks = 2, packed_size = 1 + // FP4 chunk_size = 32, num_chunks = 1, packed_size = 2 + // FP6 chunk_size = 32, num_chunks = 1, packed_size = 32 + + constexpr index_t num_chunks = is_packed_type_v ? 1 : 2; + // Here we want to load from cols of B in chunks of 16 elements each. - static constexpr uint32_t chunk_size = 16; + constexpr uint32_t chunk_size = is_packed_type_v ? 32 : 16; // each chunk is separated by an offset static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_N; // 32 or 64 @@ -501,44 +526,36 @@ __device__ BFragT load_B_col_major(BType const* input_ptr) // To start the loading process, let's visualize in 2D coords. // Each thread will load 32 elements. // We need to know where they start, and where the next elements are. - auto startCoord2D = - std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, // Row {0, 16} | {0, 16, 32, 48} - threadIdx.x % BLOCK_N); // Col {0-31} | {0-15} + // FP8/6/4 Col {0-31} | {0-15} + // FP8 Row {0, 16} | {0, 16, 32, 48} + // FP6/4 Row {0, 32} | {0, 32, 64, 96} + auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, threadIdx.x % BLOCK_N); // Flatten to 1D col_major offsets. auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; }; - // auto minorStepCoord2D = std::make_pair(1u, 0u); // read cols auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col - // BLOCK_K is a stride in B matrix - auto startOffset = col_major( - startCoord2D, BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); - // auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K / - // (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); - auto kMajorOffset = - col_major(majorStepCoord2D, - BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); - - using BRawT = typename scalar_type::type; - using BScalarFragT = vector_type::type; - - constexpr index_t num_chunks = - (ck::is_same_v, ck::f4x2_pk_t> ? 1 : 2); + using BRawT = typename scalar_type::type; + using BScalarChunkT = vector_type::vector_size / num_chunks>::type; union { BFragT frag; - BScalarFragT chunks[num_chunks]; + BScalarChunkT chunks[num_chunks]; } fragB{}; - const BScalarFragT* fragPtr; + const BScalarChunkT* fragPtr; - for(index_t chunk = 0; chunk < num_chunks; chunk++) + // BLOCK_K is a stride in B matrix + auto startOffset = col_major(startCoord2D, BLOCK_K) / packed_size_v; + auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K) / packed_size_v; + + for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++) { - fragPtr = - reinterpret_cast(input_ptr + startOffset + chunk * kMajorOffset); - fragB.chunks[chunk] = *fragPtr; + fragPtr = reinterpret_cast(input_ptr + startOffset + + chunk_idx * kMajorOffset); + fragB.chunks[chunk_idx] = *fragPtr; } return fragB.frag; @@ -904,20 +921,22 @@ template -__global__ void matmul(const AType* a, const BType* b, CType* c) +__global__ void matmul(const typename packed_type::type* a, + const typename packed_type::type* b, + CType* c) { + using PackedAType = typename packed_type::type; + constexpr auto packed_size_a = packed_type::packed_size; + using PackedBType = typename packed_type::type; + constexpr auto packed_size_b = packed_type::packed_size; + constexpr int WAVE_SIZE = 64; assert(threadIdx.x < WAVE_SIZE); assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1); - using AFragT = - vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; - using BFragT = - vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; + using AFragT = vector_type::type; + using BFragT = vector_type::type; + using CFragT = vector_type::type; using AccumFragT = vector_type; using RawAccumFragT = vector_type::type; @@ -931,11 +950,11 @@ __global__ void matmul(const AType* a, const BType* b, CType* c) // Load the inputs. if constexpr(is_same_v) { - fragA = load_A_row_major(a); + fragA = load_A_row_major(a); } else { - fragA = load_A_col_major(a); + fragA = load_A_col_major(a); } if constexpr(is_same_v) @@ -944,7 +963,7 @@ __global__ void matmul(const AType* a, const BType* b, CType* c) } else { - fragB = load_B_col_major(b); + fragB = load_B_col_major(b); } // Matrix multiply-accumulate using MFMA units @@ -979,21 +998,24 @@ template -__global__ void -matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c) +__global__ void matmul(const packed_type_t* a, + const ScaleType* xa, + const packed_type_t* b, + const ScaleType* xb, + CType* c) { + using PackedAType = packed_type_t; + constexpr auto packed_size_a = packed_size_v; + using PackedBType = packed_type_t; + constexpr auto packed_size_b = packed_size_v; + constexpr int WAVE_SIZE = 64; assert(threadIdx.x < WAVE_SIZE); assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1); - using AFragT = - vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; - using BFragT = - vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; + using AFragT = vector_type::type; + using BFragT = vector_type::type; + using CFragT = vector_type::type; using AccumFragT = vector_type; using RawAccumFragT = vector_type::type; @@ -1011,9 +1033,13 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, // Load the inputs. if constexpr(is_same_v) { - fragA = - load_mx_A_row_major( - a, xa, fragXa); + fragA = load_mx_A_row_major(a, xa, fragXa); } else { @@ -1026,9 +1052,13 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, } else { - fragB = - load_mx_B_col_major( - b, xb, fragXb); + fragB = load_mx_B_col_major(b, xb, fragXb); } // Scaled Matrix multiply-accumulate using MFMA units @@ -1151,6 +1181,11 @@ template struct TestMXMFMA { + using PackedAType = typename packed_type::type; + static constexpr auto packed_size_a = packed_type::packed_size; + using PackedBType = typename packed_type::type; + static constexpr auto packed_size_b = packed_type::packed_size; + auto PrepareGemmTensors(const GemmParams& params, index_t init) { auto f_host_tensor_descriptor = @@ -1167,11 +1202,11 @@ struct TestMXMFMA } }; - Tensor a_m_k( + Tensor a_m_k( f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); Tensor a_scales( f_host_tensor_descriptor(params.M, params.K / BLOCK_X, params.K / BLOCK_X, ALayout{})); - Tensor b_n_k( + Tensor b_n_k( f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); Tensor b_scales( f_host_tensor_descriptor(params.K / BLOCK_X, params.N, params.K / BLOCK_X, BLayout{})); @@ -1183,51 +1218,44 @@ struct TestMXMFMA switch(init) { case 0: - a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); - a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{0.015625f}}); // 1/6 + a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{0.5f}}); // NOTE: not all numbers are representable in FP8, BF8, etc. // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32 - b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{}); + b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{}); b_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{1.0f}}); break; case 1: // results in C = {K} - a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{512.0f}}); - b_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + b_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); b_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{1.0f / 512}}); break; case 2: // expect small round off errors - a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); a_scales.GenerateTensorValue( GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} - b_n_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + b_n_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); b_scales.GenerateTensorValue(GeneratorTensor_2{126, 129}); break; case 3: // expect small round off errors - a_m_k.GenerateTensorValue(GeneratorTensor_4(0, 1)); + a_m_k.GenerateTensorValue(GeneratorTensor_4(0, 1, time(nullptr))); a_scales.GenerateTensorValue( GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} - b_n_k.GenerateTensorValue(GeneratorTensor_4(0, 1)); - b_scales.GenerateTensorValue( - GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} - break; - case 4: - a_m_k.GenerateTensorValue(GeneratorTensor_3{-1., 1.}); - a_scales.GenerateTensorValue( - GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} - b_n_k.GenerateTensorValue(GeneratorTensor_3{-1., 1.}); + b_n_k.GenerateTensorValue(GeneratorTensor_4(0, 1, time(nullptr) / 2)); b_scales.GenerateTensorValue( GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} break; + default: // all initial values are representable in FP8, BF8 - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] + a_m_k.GenerateTensorValue(GeneratorTensor_2{-6, 7}); // Z[-6,6] a_scales.GenerateTensorValue( - GeneratorTensor_2{122, 129}); // scales: [1/32,..., 2] - b_n_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] + GeneratorTensor_2{122, 129}); // scales: [1/32,..., 2] + b_n_k.GenerateTensorValue(GeneratorTensor_2{-6, 7}); // Z[-6,6] b_scales.GenerateTensorValue( GeneratorTensor_2{122, 129}); // scales: [1/32,..., 2] @@ -1272,9 +1300,9 @@ struct TestMXMFMA auto host_tensors = PrepareGemmTensors(params, init); - const Tensor& a = std::get<0>(host_tensors); + const Tensor& a = std::get<0>(host_tensors); const Tensor& a_scales = std::get<1>(host_tensors); - const Tensor& b = std::get<2>(host_tensors); + const Tensor& b = std::get<2>(host_tensors); const Tensor& b_scales = std::get<3>(host_tensors); Tensor& c_host = std::get<4>(host_tensors); Tensor& c_device = std::get<5>(host_tensors); @@ -1356,6 +1384,12 @@ template struct TestMFMA { + + using PackedAType = typename packed_type::type; + static constexpr auto packed_size_a = packed_type::packed_size; + using PackedBType = typename packed_type::type; + static constexpr auto packed_size_b = packed_type::packed_size; + auto PrepareGemmTensors(const GemmParams& params, index_t init) { auto f_host_tensor_descriptor = @@ -1372,9 +1406,9 @@ struct TestMFMA } }; - Tensor a_m_k( + Tensor a_m_k( f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); - Tensor b_n_k( + Tensor b_n_k( f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); Tensor c_m_n_host_result( f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); @@ -1384,34 +1418,30 @@ struct TestMFMA switch(init) { case 0: - a_m_k.GenerateTensorValue(GeneratorTensor_1{0.015625f}); + a_m_k.GenerateTensorValue(GeneratorTensor_1{0.625f}); // NOTE: not all numbers are representable in FP8, BF8, etc. - b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{}); + b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{}); break; case 1: // results in C = {K} - a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); - b_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + b_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); break; case 2: - // expect small round off errors - a_m_k.GenerateTensorValue(GeneratorTensor_3{-5, 5}); - b_n_k.GenerateTensorValue(GeneratorTensor_3{-5, 5}); + // expect small round off errors that lead to FP8MFMA32x32x64 failures + a_m_k.GenerateTensorValue(GeneratorTensor_3{-5, 5}); + b_n_k.GenerateTensorValue(GeneratorTensor_3{-5, 5}); break; case 3: - // expect small round off errors - a_m_k.GenerateTensorValue(GeneratorTensor_4(-1, 3)); - b_n_k.GenerateTensorValue(GeneratorTensor_4(1, 3)); - break; - case 4: - // FP4 values case - a_m_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); - b_n_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); + // expect small round off errors that lead to FP8MFMA32x32x64 failures + a_m_k.GenerateTensorValue(GeneratorTensor_4(-1, 3)); + b_n_k.GenerateTensorValue(GeneratorTensor_4(1, 3)); break; + default: - // all initial values are representable in FP8, BF8 - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); - b_n_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); + // all initial values are representable in FP8/6, BF8/6 FP4 is missing 5 + a_m_k.GenerateTensorValue(GeneratorTensor_2{-6, 7}); // Z[-6,6] + b_n_k.GenerateTensorValue(GeneratorTensor_2{-6, 7}); break; } @@ -1453,10 +1483,10 @@ struct TestMFMA auto host_tensors = PrepareGemmTensors(params, init); - const Tensor& a = std::get<0>(host_tensors); - const Tensor& b = std::get<1>(host_tensors); - Tensor& c_host = std::get<2>(host_tensors); - Tensor& c_device = std::get<3>(host_tensors); + const Tensor& a = std::get<0>(host_tensors); + const Tensor& b = std::get<1>(host_tensors); + Tensor& c_host = std::get<2>(host_tensors); + Tensor& c_device = std::get<3>(host_tensors); using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -1464,8 +1494,8 @@ struct TestMFMA auto b_element_op = PassThrough{}; auto c_element_op = PassThrough{}; - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm Date: Mon, 19 May 2025 17:29:51 -0700 Subject: [PATCH 011/147] Use new mfma instructions for FP8 on gfx950 (#2202) * Add logic to use new mfma instructions for fp8 bf8 * Fix example_gemm_xdl_fp8_pk_i4_bpreshuffle_v3 on gfx950 and run clang format * Update include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> * Fix intrin_mfma f8 calls due to merge mistake --------- Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> --- example/01_gemm/gemm_xdl_fp8.cpp | 2 + ...ipeline_xdlops_b_preshuffle_dequant_v3.hpp | 4 +- ...iple_d_welford_first_half_xdl_cshuffle.hpp | 16 ++- ...wise_batched_gemm_gemm_xdl_cshuffle_v1.hpp | 11 +- ...iple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp | 33 +++-- ...ultiple_d_softmax_gemm_xdl_cshuffle_v1.hpp | 19 ++- ...ched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp | 11 +- ...e_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp | 11 +- ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 20 +-- ...emm_multiple_d_multiple_r_xdl_cshuffle.hpp | 11 +- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 20 +-- ...ultiple_d_xdl_cshuffle_lds_direct_load.hpp | 19 +-- ...se_gemm_multiple_d_xdl_splitk_cshuffle.hpp | 16 ++- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 11 +- ...e_gemm_split_k_multiple_d_xdl_cshuffle.hpp | 32 +++-- ...emm_split_k_multiple_d_xdl_cshuffle_v2.hpp | 16 ++- .../gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 13 +- .../gridwise_gemm_xdl_cshuffle_streamk_v3.hpp | 13 +- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 16 ++- .../grid/gridwise_gemm_xdl_cshuffle_v2.hpp | 16 ++- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 13 +- ...wise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 20 ++- .../gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp | 13 +- ...ridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp | 13 +- .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 14 +- ..._gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp | 14 +- ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 25 +++- ...ridwise_gemm_xdl_layernorm_cshuffle_v1.hpp | 11 +- ...ridwise_gemm_xdl_waveletmodel_cshuffle.hpp | 17 ++- .../grid/gridwise_gemm_xdlops_bwd_weight.hpp | 17 ++- .../gpu/grid/gridwise_gemm_xdlops_v3r1.hpp | 11 +- .../gpu/grid/gridwise_moe_gemm.hpp | 29 ++-- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 131 +++++++++++++++--- include/ck/utility/amd_xdlops.hpp | 90 ++++++++++++ 34 files changed, 548 insertions(+), 180 deletions(-) diff --git a/example/01_gemm/gemm_xdl_fp8.cpp b/example/01_gemm/gemm_xdl_fp8.cpp index 3c75a44d21..0c51a58037 100644 --- a/example/01_gemm/gemm_xdl_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp8.cpp @@ -32,6 +32,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; + // this instance has been tested working on gfx950 + // < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 128, 32, 32, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp index 4be4e321d3..e5fe92a50d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp @@ -124,7 +124,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3{}; + static constexpr index_t PrefetchStages = 2; static constexpr index_t PrefillStages = 1; static constexpr index_t GlobalBufferNum = 1; diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp index d728360c55..02dba97430 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -519,13 +519,19 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index 50b4a734fa..258d0ad0ca 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -452,13 +452,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_v2< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index 79a9410898..53a45c7f16 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -365,16 +365,20 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_A0K1_B0K1 <= 4) || - (is_same::value && lcm_A0K1_B0K1 <= 8)) + (is_same::value && lcm_A0K1_B0K1 <= 8) || + ((is_same::value || is_same::value) && + lcm_A0K1_B0K1 < 32)) ? true : false; - constexpr auto mfma = MfmaSelector::selected_mfma; - constexpr auto N3 = mfma.num_groups_per_blk; - constexpr auto N5 = mfma.group_size; + is_single_rate_mfma, + is_scale_mfma>::selected_mfma; + constexpr auto N3 = mfma.num_groups_per_blk; + constexpr auto N5 = mfma.group_size; return transform_tensor_descriptor( d0_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple( @@ -657,16 +661,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_A0K1_B0K1 <= 4) || - (is_same::value && lcm_A0K1_B0K1 <= 8)) + (is_same::value && lcm_A0K1_B0K1 <= 8) || + ((is_same::value || is_same::value) && + lcm_A0K1_B0K1 < 32)) ? true : false; - constexpr index_t KPack = - math::max(lcm_A0K1_B0K1, - MfmaSelector::selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_A0K1_B0K1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm0 = BlockwiseGemmXdlops_v2< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index d15767f658..0f2085525f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -347,11 +347,15 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; + constexpr auto is_scale_mfma = false; constexpr auto mfma = - MfmaSelector::selected_mfma; + MfmaSelector:: + selected_mfma; constexpr auto N3 = mfma.num_groups_per_blk; constexpr auto N4 = mfma.num_input_blks; constexpr auto N5 = mfma.group_size; @@ -564,13 +568,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_v2< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index a11d696019..33b9199ea5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -473,13 +473,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_v2< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index ab97a940a8..f406bfb95a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -502,13 +502,16 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 79ab3acd92..054aca2936 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -679,17 +679,19 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); + static constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 0e51c6904c..127d889572 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -468,13 +468,16 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index a3301dd932..be0fff087e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -647,17 +647,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index 57b9b02548..7781d1def3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -605,17 +605,20 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; + constexpr auto is_scale_mfma = false; - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp index 88d6be234c..5815eb5b0b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -603,13 +603,19 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = math::max( - lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 56581256dc..db227bb7ef 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -455,13 +455,16 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index 23b4aec3b0..70301c326a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -585,13 +585,19 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -1018,13 +1024,19 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp index 44c1e936bd..f64838ea4e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp @@ -599,13 +599,19 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = math::max( - lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index d37b3cd38e..4d3ae93659 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -83,13 +83,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp index e5e32a8535..4e72255d31 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -144,13 +144,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; __host__ static auto CalculateMPadded(index_t M) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 240bc464e1..7edcd7270f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -814,13 +814,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = math::max( - lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp index c7d44e842d..f92268265f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -873,13 +873,19 @@ struct GridwiseGemm_xdl_cshuffle_v2 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = math::max( - lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< // BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 29150c0688..0dbbc2a5e9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -255,13 +255,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index a22fc06a50..cfa8bfeb2a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -148,13 +148,21 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle static constexpr auto AK1Number = Number{}; static constexpr auto BK1Number = Number{}; - using mfma_selector = MfmaSelector; + // Use singal rate mfma instruction for this special case A (f8_t) * B (pk_i4_t) + // See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3 + // TODO: explore optimization opportunity by using new mfma instructions on gfx950 + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = true; + static constexpr auto is_scale_mfma = false; + static constexpr auto mfma = MfmaSelector{}; + static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); + static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops(); - static constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); - - static constexpr index_t KLane = - mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); static constexpr index_t KRepeat = KPerBlock / KLane / KPack; static constexpr index_t NLane = NPerXdl; static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp index 7124687d5d..93c1779a80 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -160,13 +160,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index ac3e821340..97d0e2a4eb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -198,13 +198,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index 4163d1d01a..38ce9536ab 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -183,14 +183,20 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index 21812380c2..ef84dd182a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -153,14 +153,20 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index c0d9464136..8fb955c561 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -164,12 +164,25 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle static constexpr index_t NumDTensor = DsDataType::Size(); - using mfma_selector = MfmaSelector; - static constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); - static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1; - static constexpr index_t KLane = - mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + static constexpr auto is_scale_mfma = false; + static constexpr auto mfma = MfmaSelector{}; + static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); + static constexpr index_t KGroup = mfma.selected_mfma.k_per_blk == 32 ? 2 : 1; + static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops(); static constexpr index_t KPackPerGroup = KPack / KGroup; static constexpr index_t KRepeat = KPerBlock / KLane / KPackPerGroup; static constexpr index_t NLane = NPerXdl; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index b435fd5d5a..67fb4d651e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -493,13 +493,16 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp index ad65e75ef9..50363d832e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp @@ -491,13 +491,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = math::max( - lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< TileMathThreadGroupSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 168c553180..b7947309e4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -744,14 +744,19 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && K1 <= 4) || - (is_same::value && K1 <= 8)) + (is_same::value && K1 <= 8) || + ((is_same::value || is_same::value) && + K1 < 32)) ? true : false; - - constexpr index_t KPack = math::max( - K1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(K1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t k_pack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t k_pack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - static constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); - static constexpr index_t KLane = - mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); - static constexpr index_t KRepeat = KPerBlock / KLane / KPack; - static constexpr index_t NLane = NPerXdl; - static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + static constexpr auto is_scale_mfma = false; + static constexpr auto mfma = MfmaSelector{}; + static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); + static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops(); + static constexpr index_t KRepeat = KPerBlock / KLane / KPack; + static constexpr index_t NLane = NPerXdl; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; // static constexpr index_t NumTokens = 1; static constexpr index_t SortedTileSize = MPerBlock; diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 06268f3cfb..b825d7ab69 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1117,12 +1117,31 @@ struct MfmaSelector #endif } + // Use singal rate mfma instruction for this special case A (f8_t) * B (pk_i4_t) + // See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3 + // TODO: explore optimization opportunity by using new mfma instructions on gfx950 template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16f8f8; } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16f8f8; + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else + return MfmaInstr::mfma_f32_32x32x16f8f8; +#endif + } + template <> constexpr auto GetMfma() { @@ -1136,11 +1155,21 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32f8f8; } + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else + return MfmaInstr::mfma_f32_16x16x32f8f8; +#endif + } + template <> constexpr auto GetMfma() { @@ -1166,41 +1195,101 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16bf8bf8; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else + return MfmaInstr::mfma_f32_32x32x16bf8bf8; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32bf8bf8; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else + return MfmaInstr::mfma_f32_16x16x32bf8bf8; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16f8bf8; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else + return MfmaInstr::mfma_f32_32x32x16f8bf8; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32f8bf8; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else + return MfmaInstr::mfma_f32_16x16x32f8bf8; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16bf8f8; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else + return MfmaInstr::mfma_f32_32x32x16bf8f8; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32bf8f8; } + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else + return MfmaInstr::mfma_f32_16x16x32bf8f8; +#endif + } + static constexpr auto selected_mfma = mfma_type::value || - is_same::value) && - KPack <= 4) || - (is_same::value && KPack <= 8)) - ? true - : false, - is_scale_mfma > {}; + // Falls back to single rate instruction on gfx950 if KPack is single rate; no change on gfx942- + // when base_type is either f8_t or bf8_t, additional_type will always be either f8_t or bf8_t, + // except Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t) + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + KPack <= 4) || + (is_same::value && KPack <= 8) || + ((is_same::value || is_same::value) && KPack < 32) || + is_same::value) + ? true + : false; + static constexpr auto mfma = MfmaSelector{}; static constexpr auto mfma_instr = mfma.selected_mfma; diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index ad48389625..ed3354dfb5 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -533,6 +533,50 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) { @@ -1118,6 +1162,52 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) { From 0970f22221917d559d0bd72c1065303f08a70450 Mon Sep 17 00:00:00 2001 From: Jan Patrick Lehr Date: Tue, 20 May 2025 02:30:15 +0200 Subject: [PATCH 012/147] [CMake] Disable newly added compiler warning -Wnrvo (#2210) Recently a new warning was added to Clang to warn when no copy-elision on return happens. That prevents our CK build. This disables the warning. --- CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e12462a41..13606975c0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,6 +94,9 @@ add_compile_options(-Wno-pass-failed) add_compile_options(-Wno-switch-default) add_compile_options(-Wno-unique-object-duplication) +# Recent change in compiler makes this warning ON by default, which led to compile errors. +add_compile_options(-Wno-nrvo) + if(NOT DISABLE_DL_KERNELS) add_definitions(-DDL_KERNELS) set(DL_KERNELS "ON") From c4929225f60f56e3a9320547dcfdff30a77f0aa3 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Mon, 19 May 2025 19:31:04 -0500 Subject: [PATCH 013/147] remove debug statements from CMakeLists (#2204) --- example/CMakeLists.txt | 2 -- test/CMakeLists.txt | 4 ---- 2 files changed, 6 deletions(-) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 996a543ecc..9c30a2e255 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -135,11 +135,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endif() #message("add_example returns ${result}") if(result EQUAL 0 AND NOT "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES) - #message("adding to SMOKE EXAMPLE FILTER ${EXAMPLE_NAME}") set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "SMOKE_TEST") add_dependencies(smoke ${EXAMPLE_NAME}) elseif(result EQUAL 0 AND "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES) - #message("Adding to REGRESSION EXAMPLE FILTER ${EXAMPLE_NAME}") set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "REGRESSION_TEST") add_dependencies(regression ${EXAMPLE_NAME}) endif() diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 69ffb94488..5ea61d2dfc 100755 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -121,11 +121,9 @@ function(add_test_executable TEST_NAME) #message("add_test returns ${result}") set(result ${result} PARENT_SCOPE) if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS) - message("adding to SMOKE TEST FILTER ${TEST_NAME}") set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST") add_dependencies(smoke ${TEST_NAME}) elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS) - message("Adding to REGRESSION TEST FILTER ${TEST_NAME}") set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST") add_dependencies(regression ${TEST_NAME}) endif() @@ -222,11 +220,9 @@ function(add_gtest_executable TEST_NAME) #message("add_gtest returns ${result}") set(result ${result} PARENT_SCOPE) if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS) - #message("adding to smoke test FILTER ${TEST_NAME}") set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST") add_dependencies(smoke ${TEST_NAME}) elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS) - #message("Adding to REGRESSION TEST FILTER ${TEST_NAME}") set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST") add_dependencies(regression ${TEST_NAME}) endif() From d1e6f0982d04d6b356f001da731dd5e315f78812 Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Tue, 20 May 2025 17:18:57 +0300 Subject: [PATCH 014/147] [CK_TILE] Grouped GEMM tile loop (#2146) * Add trait to use a persistent kernel and split the entrypoints in grouped gemm * Some helper functions for persistent kernel case * Get max occupancy grid using device properties * Implement tile loop in main entry point to grouped gemm * Enable GridSize() on device * Handle offset tile index using real current block index * Add persistent kernel choice to grouped gemm example * Use a for-loop for iterating over the group * Reduce VGPR spills by early-exit * Enable persistent kernel choice in grouped_gemm example * Add persistent kernel option to grouped_gemm test * Fix formatting with remod.py * Remove GridUpdateBlocks as blocks are now iteratively computed * Add comment about VGPR spilling * Fix formatting * Use CK_TILE_HOST instead of __host__ * Enable all Row/Col combinations in grouped gemm unit test * Add some KBatch=2 cases to grouped gemm tests * Fix SplitK for grouped gemm * Enable pipeline hotloop/tailnumber selection in-kernel for grouped gemm * Add type traits * Split examples to regular and tileloop * Formatting * Use hipExtStreamGetCUMask to get current active CUs for the given stream * Align test and example kernel config, and disable validation for splitk repeats * Remove debug options from CMakeLists.txt * Separate the code paths for persistent/non-persistent in test * Fix formatting * Address review comments --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> --- .../ck_tile/17_grouped_gemm/CMakeLists.txt | 2 +- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 132 ++++----- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 17 +- .../17_grouped_gemm/grouped_gemm_tileloop.cpp | 174 ++++++++++++ .../run_grouped_gemm_example.inc | 98 +++++-- include/ck_tile/core/utility/type_traits.hpp | 11 + include/ck_tile/host/stream_utils.hpp | 45 ++++ .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 18 +- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 252 ++++++++++++++++-- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 12 +- ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 0 .../ops/gemm/pipeline/tile_gemm_traits.hpp | 24 +- .../grouped_gemm/test_grouped_gemm.cpp | 30 ++- .../test_grouped_gemm_ut_cases.inc | 30 ++- .../grouped_gemm/test_grouped_gemm_util.hpp | 209 +++++++++++++-- 15 files changed, 908 insertions(+), 146 deletions(-) create mode 100644 example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp create mode 100644 include/ck_tile/host/stream_utils.hpp mode change 100755 => 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index d34013dd6c..79df4e624d 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) - +add_executable(tile_example_grouped_gemm_tileloop EXCLUDE_FROM_ALL grouped_gemm_tileloop.cpp) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 9b134ff779..61193e2e29 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -16,15 +16,10 @@ #include "ck_tile/host.hpp" #include "grouped_gemm.hpp" -std::size_t get_workspace_size(const std::vector& gemm_descs) -{ - return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); -} - template float grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, - void* p_workspace_) + void* kargs_ptr) { #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) // Memory friendly for Interwave scheduler @@ -114,70 +109,76 @@ float grouped_gemm(const std::vector& gemm_descs, float ave_time{0}; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = GEMM_PIPELINE; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); + using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } - const dim3 grids = Kernel::GridSize(gemm_descs); - constexpr dim3 blocks = Kernel::BlockSize(); + constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); - ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(p_workspace_), - gemm_descs.size())); - return ave_time; - }; + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + + return ave_time; + }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(gemm_descs[0].k_batch == 1) @@ -317,4 +318,5 @@ float grouped_gemm(const std::vector& gemm_descs, #include "run_grouped_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } +constexpr bool Persistent = false; +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 4fec329c2f..77db182c72 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -70,14 +70,25 @@ auto create_args(int argc, char* argv[]) .insert("validate", "1", "0. No validation, 1. Validation on CPU.") .insert("warmup", "10", "number of iterations before benchmark the kernel.") .insert("repeat", "100", "number of iterations to benchmark the kernel.") - .insert("group_count", "8", "group count."); + .insert("group_count", "8", "group count.") + .insert("kbatch", "1", "kbatch for SplitK"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } -std::size_t get_workspace_size(const std::vector& gemm_descs); +inline std::size_t get_workspace_size(const std::vector& gemm_descs) +{ + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); +} +template float grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, - void* p_workspace_); + void* kargs_ptr); + +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk = false); diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp new file mode 100644 index 0000000000..5c0cb92683 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "grouped_gemm.hpp" + +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk) +{ +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) + // Memory friendly for Interwave scheduler + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 4; + constexpr ck_tile::index_t N_Warp = 1; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; + + constexpr bool DoubleSmemBuffer = false; +#endif +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + // Compute friendly for Intrawave scheduler + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = true; +#endif + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + float ave_time{0}; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = memory_operation_.value; + + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + + return ave_time; + }; + + if(!splitk) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } + + return ave_time; +} + +#include "run_grouped_gemm_example.inc" + +constexpr bool Persistent = true; +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index f068510d26..a01d8178cc 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -30,20 +30,60 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template +template float invoke_gemm(int n_warmup, int n_repeat, int group_count, const std::vector& args) { - + // Workspace memory allocated to hold the gemm descriptions. ck_tile::DeviceMem gemm_workspace; gemm_workspace.Realloc(get_workspace_size(args)); - float ave_time = grouped_gemm( - args, - ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, - gemm_workspace.GetDeviceBuffer()); + float ave_time = 0; + if constexpr(!Persistent) + { + // Regular version of grouped gemm + ave_time = grouped_gemm( + args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); + } + else + { + // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have + // the gemm problems known on the host. Instead, we can just pass the pointer + // to the kernel and let the workgroups figure out which tiles to work on. + // This is useful when the gemm problems are generated dynamically. + // In this example however, we generate the `kargs` using the known gemm_descs, + // and copy the gemm descriptions to the device memory. + // The contents of the memory pointed to by `kargs_ptr` pointer could be + // written by e.g. another kernel from earlier stage. + std::vector kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + const bool splitk = args[0].k_batch > 1; + for(const auto& arg : args) + { + kargs.emplace_back(ck_tile::GemmKernelArgs{arg.a_ptr, + arg.b_ptr, + arg.c_ptr, + arg.M, + arg.N, + arg.K, + arg.stride_A, + arg.stride_B, + arg.stride_C, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::GemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + ave_time = grouped_gemm_tileloop( + stream, group_count, kargs_ptr, splitk); + } std::string op_name{"Grouped Gemm"}; @@ -66,7 +106,7 @@ float invoke_gemm(int n_warmup, return ave_time; } -template +template int run_grouped_gemm_example_with_layouts(int argc, char* argv[], const ALayout a_layout = ALayout{}, @@ -87,6 +127,15 @@ int run_grouped_gemm_example_with_layouts(int argc, const int group_count = arg_parser.get_int("group_count"); const int repeat = arg_parser.get_int("repeat"); const int warmup = arg_parser.get_int("warmup"); + const int kbatch = arg_parser.get_int("kbatch"); + bool validate = arg_parser.get_bool("validate"); + + if(kbatch > 1 && validate && warmup + repeat > 1) + { + std::cout << "WARNING: Data validation enabled with SplitK and more than" + << "1 warmup/repeat. Disabling validation." << std::endl; + validate = false; + } std::vector Ms = arg_parser.get_int_vec("Ms"); std::vector Ns = arg_parser.get_int_vec("Ns"); @@ -102,7 +151,7 @@ int run_grouped_gemm_example_with_layouts(int argc, { Ms.push_back(256 + 256 * i); Ns.push_back(256 + 512 * i); - Ks.push_back(256 + 64 * i); + Ks.push_back(512 + 128 * i); stride_As.push_back(Ks[i]); stride_Bs.push_back(Ks[i]); @@ -150,8 +199,8 @@ int run_grouped_gemm_example_with_layouts(int argc, << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl; - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k_tensors[i]); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); a_m_k_dev_buf.push_back(std::make_unique( a_m_k_tensors[i].get_element_space_size_in_bytes())); @@ -169,13 +218,11 @@ int run_grouped_gemm_example_with_layouts(int argc, const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); - // TODO Add support for kbatch > 1 in grouped gemm - static constexpr ck_tile::index_t k_batch = 1; gemm_descs.push_back( - {p_a, p_b, p_c, k_batch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); } - invoke_gemm(warmup, repeat, group_count, gemm_descs); + invoke_gemm(warmup, repeat, group_count, gemm_descs); for(int i = 0; i < group_count; i++) { @@ -183,7 +230,7 @@ int run_grouped_gemm_example_with_layouts(int argc, } bool pass{true}; - if(arg_parser.get_int("validate")) + if(validate) { for(int i = 0; i < group_count; ++i) { @@ -194,7 +241,7 @@ int run_grouped_gemm_example_with_layouts(int argc, a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol(Ks[i], 1 /*kbatch*/, max_accumulated_value); + const auto rtol_atol = calculate_rtol_atol(Ks[i], kbatch, max_accumulated_value); pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref, "Error: Incorrect results!", @@ -211,6 +258,7 @@ int run_grouped_gemm_example_with_layouts(int argc, return pass; } +template int run_grouped_gemm_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -227,12 +275,20 @@ int run_grouped_gemm_example(int argc, char* argv[]) if(a_layout == "R" && b_layout == "C") { - return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "R" && b_layout == "R") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); } - // else if(a_layout == "R" && b_layout == "R") - // { - // return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); - // } else { throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index b432cfcef7..2e82e21ba1 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -127,4 +127,15 @@ struct is_any_of { }; +// Helper to check if a type is a specialization of a given template +template class RefTemplate> +struct is_specialization_of : std::false_type +{ +}; + +template