diff --git a/.azuredevops/rocm-ci.yml b/.azuredevops/rocm-ci.yml index 4161c2d5a4..b37b8cc27f 100644 --- a/.azuredevops/rocm-ci.yml +++ b/.azuredevops/rocm-ci.yml @@ -14,6 +14,7 @@ trigger: branches: include: - develop + - amd-develop paths: exclude: - .github diff --git a/CMakeLists.txt b/CMakeLists.txt index 86ad9d39d8..1fe1bc91d5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -103,7 +103,7 @@ if(DPP_KERNELS) endif() option(CK_USE_CODEGEN "Enable codegen library" OFF) if(CK_USE_CODEGEN) - add_definitions(-DCK_USE_CODEGEN) + add_definitions(-DCK_USE_CODEGEN) endif() option(CK_TIME_KERNEL "Enable kernel time tracking" ON) @@ -196,17 +196,20 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") add_definitions(-DCK_USE_XDL) set(CK_USE_XDL "ON") endif() -if (SUPPORTED_GPU_TARGETS MATCHES "gfx94") +if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") message("Enabling FP8 gemms on native architectures") add_definitions(-DCK_USE_GFX94) set(CK_USE_GFX94 "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx95") + add_definitions(-DCK_USE_AMD_MFMA_GFX950) +endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") message("Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") endif() -if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") +if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") add_definitions(-DCK_USE_OCP_FP8) set(CK_USE_OCP_FP8 "ON") endif() @@ -214,6 +217,10 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx add_definitions(-DCK_USE_FNUZ_FP8) set(CK_USE_FNUZ_FP8 "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx950") + add_definitions(-DCK_USE_NATIVE_MX_SUPPORT) + set(CK_USE_NATIVE_MX_SUPPORT "ON") +endif() option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) diff --git a/Jenkinsfile b/Jenkinsfile index 2d8f7561f2..835b7e724f 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -795,8 +795,8 @@ pipeline { description: "Run the ck_tile FMHA tests (default: OFF)") booleanParam( name: "RUN_CK_TILE_GEMM_TESTS", - defaultValue: false, - description: "Run the ck_tile GEMM tests (default: OFF)") + defaultValue: true, + description: "Run the ck_tile GEMM tests (default: ON)") booleanParam( name: "BUILD_INSTANCES_ONLY", defaultValue: false, diff --git a/client_example/01_gemm/README.md b/client_example/01_gemm/README.md new file mode 100644 index 0000000000..6dcd1e2959 --- /dev/null +++ b/client_example/01_gemm/README.md @@ -0,0 +1,126 @@ +[Back to supported operations](../../../include/ck/README.md) +# Composable Kernel GEMM + +## GEMM +General matrix multiplications operation. In CK GEMM operation is called as `DeviceGemm` and requires following types as template parameters: + +* **ALayout** - A matrix layout (RowMajor/ColumnMajor). +* **BLayout** - B matrix layout (RowMajor/ColumnMajor). +* **CLayout** - B matrix layout (RowMajor/ColumnMajor). +* **ADataType** - A matrix data type. +* **BDataType** - B matrix data type. +* **CDataType** - B matrix data type. +* **AElementwiseOperation** - Fused operation on tensor A before GEMM. +* **BElementwiseOperation** - Fused operation on tensor B before GEMM. +* **CElementwiseOperation** - Fused operation on tensor C after GEMM. + +For matrices with large K dimension `DeviceGemmSplitK` implementation is available. This implementation allows user to split K dimension between work groups. This implementation uses `AtomicAdd` operation on global memory, thus need to zero-out output buffer for correct results. + +For fused operations with additional tensor there are `DeviceGemmMultipleABD` or `DeviceGemmMultipleD` operation which require following parameters: +* **DsLayout** - layouts for additional tensors for fused operations. +* **DsDataType** - data types for additional tensors for fused operations. + +For `DeviceGemmMultipleABD` **ALayout**, **BLayout**, **ADataType** and **BDataType** user should pass a tuple. + +List of the device operations in CK: + +* **DeviceGemmDl** - Device operation with DL instructions. +* **DeviceGemmDpp** - Device operation with DL instructions with DPP instructions during data load. +* **DeviceGemmWmma_CShuffle** - Device operation with WMMA instructions with CShuffle optimization for more optimized data store. +* **DeviceGemm_Xdl_CShuffle_LdsDirectLoad** - Device operation with XDL instructions and CShuffle optimization for more optimized data store and direct load from global memory to shared memory. +* **DeviceGemm_Xdl_CShuffle** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. +* **DeviceGemm_Xdl_CShuffleV2** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. GEMM pipeline has been optimized compared to **DeviceGemm_Xdl_CShuffle**. +* **DeviceGemmXdlSkipBLds** - Device operation with XDL instructions. Load to shared memory has been skiped for B matrix. +* **DeviceGemm_Xdl_WaveletModel_CShuffle** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. Producer and consumer scheme cooperation between waves in workgroup. +* **DeviceGemmXdl** - Device operation with XDL instructions. + +Table of supported cases by instance factory with XDL instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row: + +| |Is supported| +|-------|---| +|bf16|✓| +|fp16|✓| +|fp32|✓| +|int8|✓| +|fp8 |✓| + +Table of supported cases by instance factory with WMMA instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row: + +| |Is supported| +|-------|---| +|bf16|✓| +|fp16|✓| +|fp32|✗| +|int8|✓| +|fp8 |✗| + +Table of supported cases by instance factory with DL instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row: + +| |Is supported| +|-------|---| +|bf16|✗| +|fp16|✓| +|fp32|✓| +|int8|✓| +|fp8 |✗| + +Table of supported cases by instance factory with fused output elementwise operation: + +* **B Matrix Multiply + Add + Gelu** - bf16 (int8 for B matrix) +* **B Matrix Multiply + Add** - bf16 (int8 for B matrix) +* **B Matrix Multiply + Gelu** - bf16 (int8 for B matrix) +* **B Matrix Multiply** - bf16 (int8 for B matrix) + +* **Add + Add + Gelu** - fp16 +* **Add + Gelu** - fp16, bf16 (int8 for B matrix) for Row/Column/Row +* **Multiply** - fp16 +* **Add + Multiply** - fp16 +* **Add + Relu** - fp16 (int8 for B matrix) for Row/Column/Row, bf16 (int8 for B matrix) for Row/Column/Row +* **Add + Silu** - fp16 (int8 for B matrix) for Row/Column/Row, bf16 (int8 for B matrix) for Row/Column/Row +* **Add** - fp16 (int8 for B matrix) for Row/Column/Row, bf16 (int8 for B matrix) for Row/Column/Row +* **Bilinear** - fp16, int8 +* **Gelu** - fp16 +* **Multiply + Add** - fp16 for Row/Column/Row and Row/Row/Row, fp16 (int8 for B matrix, fp32 for Bias) for Row/Column/Row and Row/Row/Row, +* **Quantization** - int8 + +## GEMM V2 (Universal GEMM) +General matrix multiplications operation optimized for MI300 series. Operation is called as `DeviceGemmV2` and requires following types as template parameters: + +* **ALayout** - A matrix layout (RowMajor/ColumnMajor). +* **BLayout** - B matrix layout (RowMajor/ColumnMajor). +* **CLayout** - B matrix layout (RowMajor/ColumnMajor). +* **ADataType** - A matrix data type. +* **BDataType** - B matrix data type. +* **CDataType** - B matrix data type. +* **AElementwiseOperation** - Fused operation on tensor A before GEMM. +* **BElementwiseOperation** - Fused operation on tensor B before GEMM. +* **CElementwiseOperation** - Fused operation on tensor C after GEMM. + +This implementation allows user to split K dimension between work groups. This implementation requires AtomicAdd operation on global memory (output buffer must be set to zeroes if splitK parameter is larger than one). + +List of the device operations for in CK: + +* **DeviceGemm_Xdl_CShuffleV3** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. +* **DeviceGemm_Xdl_CShuffleV3R1** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. This implementation perform reduction on splitted K dimension after GEMM instead of AtomicAdd instruction. + +Table of supported cases by instance factory with XDL instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row: + +| |Is supported| +|-------|---| +|bf16|✓| +|fp16|✓| +|fp32|✗| +|int8|✗| +|fp8 (C bf16)|✓| +|fp16 (A fp8)|✓| +|fp16 (B fp8)|✓| + +## Others + +* **DeviceGemm_dequantB** - GEMM with dequantization (implemented with WMMA instructions). +* **DeviceGemmMultipleD_ABScale** - GEMM with scale for A and B matrix. +* **DeviceGemmMultipleDLayernorm** - GEMM fused with layernorm. +* **DeviceGemmMultipleDMultipleR** - GEMM fused with reductions and custom global reductions operators. +* **DeviceGemmReduce** - GEMM fused with reduction. +* **DeviceGemm_Streamk_V2** - GEMM stream K implementation. Implementation allows to use reduction instead of AtomicAdd. +* **DeviceGemmStreamK** - GEMM stream K implementation using AtomicAdd. diff --git a/client_example/07_grouped_convnd_fwd/README.md b/client_example/07_grouped_convnd_fwd/README.md new file mode 100644 index 0000000000..28a64ad733 --- /dev/null +++ b/client_example/07_grouped_convnd_fwd/README.md @@ -0,0 +1,68 @@ +[Back to supported operations](../../../include/ck/README.md) +# Composable Kernel Grouped Convolution + +## Grouped Convolution Forward +Grouped convolution operation for 1D, 2D or 3D spatial dimensions. Convolution utilizes GEMM kernel after tensor coordinate transform. In CK Grouped Convolution Forward operation is called as `DeviceGroupedConvFwdMultipleABD` and requires following types as template parameters: + +* **NumDimSpatial** - number of spatial dimensions (1D, 2D, 3D). +* **InLayout** - input layout (NHWGC, GNHWC, NGCHW). +* **WeiLayout** - weight layout (GKYXC). +* **DsLayout** - layouts for additional tensors for fused operations. +* **OutLayout** - output layout (NHWGK, GNHWK, NGKHW). +* **ADataType** - input data type. Pass tuple if there is fused operation with input. +* **BDataType** - weight data type. Pass tuple if there is fused operation with weight. +* **DsDataType** - data types for additional tensors for fused operations. +* **EDataType** - Output data type. +* **AElementwiseOperation** - fused operation on tensor A (input). +* **BElementwiseOperation** - fused operation on tensor B (weight). +* **CDEElementwiseOperation** - fused operation on tensor C (output). +* **AComputeType** - compute data type of tensor A for mfma instruction (ADataType by default). +* **BComputeType** - compute data type of tensor B for mfma instruction (AComputeType by default). + +Grouped convolution forward support tensors larger than 2GB. + +List of the device operations for grouped convolution forward in CK: + +* **DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3** - Device operation with XDL instructions. Optimized for AMD Instinct MI300 series. +* **DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle** - Device operation with XDL instructions and support of fused operations to input, weight and output. +* **DeviceGroupedConvFwdMultipleD_Wmma_CShuffle** - Device operation with WMMA instructions. +* **DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK** - Device operation with DL instructions. + +Table of supported cases by instance factory with XDL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16 |2D, 3D|2D|1D, 2D, 3D| +|fp16 |2D, 3D|2D|1D, 2D, 3D| +|fp32 |2D, 3D|2D|1D, 2D, 3D| +|int8 |2D, 3D|2D|1D, 3D| +|fp8 |3D|✗|✗| +|bf8 |3D|✗|✗| + +Table of supported cases by instance factory with WMMA instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|fp16 |2D, 3D|✗|2D, 3D| +|int8 |2D, 3D|✗|2D, 3D| + +Table of supported cases by instance factory with DL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16 |✗|✗|2D| +|fp16 |✗|✗|2D| +|fp32 |✗|✗|2D| +|int8 |✗|✗|2D| + +Table of supported cases by instance factory with fused elementwise operation: + +* **Dynamic elementwise operation** - 2D/3D, NHWGC, bf16/fp16/fp32/int8 +* **Bilinear** - 3D, NHWGC, bf16/fp16/fp32/int8 +* **ConvInvScale** - 3D, NHWGC, fp8 +* **ConvScale** - 3D, NHWGC, fp8/bf8 +* **ConvScale + Add** - 3D, NHWGC, fp8 +* **ConvScale + Relu** - 3D, NHWGC, fp8 +* **Scale** - 3D, NHWGC, bf16/fp16/fp32/int8 +* **Scale + Add (for A and B)** - 3D, NHWGC, bf16/fp16/fp32/int8 +* **Scale + Add + Scale + Add + Relu** - 3D, NHWGC, bf16/fp16/fp32/int8 diff --git a/client_example/10_grouped_convnd_bwd_data/README.md b/client_example/10_grouped_convnd_bwd_data/README.md new file mode 100644 index 0000000000..0ed133310e --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/README.md @@ -0,0 +1,48 @@ +[Back to supported operations](../../../include/ck/README.md) +# Composable Kernel Grouped Convolution + +## Grouped Convolution Backward Data + +Grouped convolution operation for 1D, 2D or 3D spatial dimensions. Convolution utilizes GEMM kernel after tensor coordinate transform. In CK Grouped Convolution Backward Data operation is called as `DeviceGroupedConvBwdDataMultipleD` and requires following types as template parameters: + +* **NumDimSpatial** - number of spatial dimensions (1D, 2D, 3D). +* **ALayout** - output layout (NHWGK, GNHWK, NGKHW). +* **BLayout** - weight layout (GKYXC). +* **DsLayout** - layouts for additional tensors for fused operations. +* **ELayout** - input layout (NHWGC, GNHWC, NGCHW). +* **ADataType** - output data type. +* **BDataType** - weight data type. +* **DsDataType** - data types for additional tensors for fused operations. +* **EDataType** - input data type. +* **AElementwiseOperation** - fused operation on tensor A (output). +* **BElementwiseOperation** - fused operation on tensor B (weight). +* **CDEElementwiseOperation** - fused operation on tensor C (input). +* **AComputeType** - compute data type of tensor A for mfma instruction (ADataType by default). +* **BComputeType** - compute data type of tensor B for mfma instruction (AComputeType by default). + +Grouped convolution backward data supports tensors larger than 2GB (except when image is larger than 2GB). + +List of the device operations for grouped convolution backward data in CK: + +* **DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1** - Device operation with XDL instructions and support of fused operations to input. +* **DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle** - Device operation with WMMA instructions. + +Table of supported cases by instance factory with XDL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16|2D, 3D|✗|2D, 3D| +|fp16 |2D, 3D|✗|2D, 3D| +|fp32 |2D, 3D|✗|2D, 3D| + +Table of supported cases by instance factory with WMMA instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|fp16 |2D, 3D|✗|2D, 3D| +|int8 |2D, 3D|✗|2D, 3D| + +Table of supported cases by instance factory with fused elementwise operation: + +* **Bilinear** - 3D, NHWGC, bf16/fp16/fp32 +* **Scale** - 3D, NHWGC, bf16/fp16/fp32 diff --git a/client_example/11_grouped_conv_bwd_weight/README.md b/client_example/11_grouped_conv_bwd_weight/README.md new file mode 100644 index 0000000000..ed3dff0f1e --- /dev/null +++ b/client_example/11_grouped_conv_bwd_weight/README.md @@ -0,0 +1,62 @@ +[Back to supported operations](../../../include/ck/README.md) +# Composable Kernel Grouped Convolution + +## Grouped Convolution Backward Weight + +Grouped convolution operation for 1D, 2D or 3D spatial dimensions. Convolution utilizes GEMM kernel after tensor coordinate transform. Backward weight version uses splitK feature (due to large GEMM K dimension). In CK Grouped Convolution Backward Weight operation is called as `DeviceGroupedConvBwdWeight` and requires following types as template parameters: + +* **NumDimSpatial** - number of spatial dimensions (1D, 2D, 3D). +* **InLayout** - input layout (NHWGC, GNHWC, NGCHW). +* **WeiLayout** - weight layout (GKYXC). +* **OutLayout** - output layout (NHWGK, GNHWK, NGKHW). +* **InDataType** - input data type. +* **WeiDataType** - weight data type. +* **OutDataType** - output data type. +* **InElementwiseOperation** - fused operation on tensor input. +* **WeiElementwiseOperation** - fused operation on tensor weight. +* **OutElementwiseOperation** - fused operation on tensor output. +* **ComputeTypeA** - compute data type of tensor A for mfma instruction (ADataType by default). +* **ComputeTypeB** - compute data type of tensor B for mfma instruction (ComputeTypeA by default). + +For fused operations with additional tensor there is `DeviceGroupedConvBwdWeightMultipleD` operation which requires following parameters: +* **DsLayout** - layouts for additional tensors for fused operations. +* **DsDataType** - data types for additional tensors for fused operations. + +Grouped convolution backward weight doesn't supports tensors larger than 2GB. + +List of the device operations for grouped convolution backward weight in CK: + +* **DeviceGroupedConvBwdWeight_Xdl_CShuffle** - Device operation with XDL instructions. +* **DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle** - Device operation with XDL instructions. Optimized for small C or K. +* **DeviceGroupedConvBwdWeight_Wmma_CShuffle** - Device operation with WMMA instructions. +* **DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle** - Device operation with XDL instructions and support of fused operations to output. +* **DeviceGroupedConvBwdWeight_Dl** - Device operation with DL instructions. + +Table of supported cases by instance factory with XDL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16|2D, 3D|✗|✗| +|bf16(fp32 for weight)|2D, 3D|✗|1D, 2D, 3D| +|fp16 |2D, 3D|✗|1D, 2D, 3D| +|fp32 |2D, 3D|✗|1D, 2D, 3D| + +Table of supported cases by instance factory with WMMA instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|fp16 |3D|✗|3D| +|int8 |3D|✗|3D| + +Table of supported cases by instance factory with DL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16(fp32 for weight)|1D, 2D, 3D|✗|1D, 2D, 3D| +|fp16 |1D, 2D, 3D|✗|1D, 2D, 3D| +|fp32 |1D, 2D, 3D|✗|1D, 2D, 3D| + +Table of supported cases by instance factory with fused elementwise operation: + +* **Bilinear** - 3D, NHWGC, bf16(fp32 for weight)/fp16/fp32 +* **Scale** - 3D, NHWGC, bf16(fp32 for weight)/fp16/fp32 diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index ce5834d1e2..9e2012bf8a 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -56,7 +56,7 @@ if (GPU_TARGETS) add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") endif() - if (GPU_TARGETS MATCHES "gfx12") + if (GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx950") add_definitions(-DCK_USE_OCP_FP8) set(CK_USE_OCP_FP8 "ON") endif() diff --git a/codegen/driver/main.cpp b/codegen/driver/main.cpp index c7d295de94..7b878d0d57 100644 --- a/codegen/driver/main.cpp +++ b/codegen/driver/main.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/codegen/src/headers.cpp b/codegen/src/headers.cpp index 5b0c929db3..452cd99846 100644 --- a/codegen/src/headers.cpp +++ b/codegen/src/headers.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/headers.hpp" #include "ck_headers.hpp" diff --git a/codegen/src/types.cpp b/codegen/src/types.cpp index a8a8b10c04..9aa5d39fae 100644 --- a/codegen/src/types.cpp +++ b/codegen/src/types.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/types.hpp" #include "ck/host/stringutils.hpp" #include diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp index bd7ef463fb..9e2d990d9b 100644 --- a/codegen/test/gemm_multiple_d.cpp +++ b/codegen/test/gemm_multiple_d.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_gemm_multiple_d/problem.hpp" #include "ck/host/device_gemm_multiple_d/operation.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp index 50290fa25a..9902caab04 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp index b558d97c78..205283e7aa 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp index e2972a93d2..2b83af2432 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp index b728096c51..fbe27e9c8b 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/include/common.hpp b/codegen/test/include/common.hpp index 99d4c64973..24fde2e523 100644 --- a/codegen/test/include/common.hpp +++ b/codegen/test/include/common.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include #include diff --git a/codegen/test/rtc/include/rtc/compile_kernel.hpp b/codegen/test/rtc/include/rtc/compile_kernel.hpp index c4413b47be..a49714f7c6 100644 --- a/codegen/test/rtc/include/rtc/compile_kernel.hpp +++ b/codegen/test/rtc/include/rtc/compile_kernel.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL diff --git a/codegen/test/rtc/include/rtc/hip.hpp b/codegen/test/rtc/include/rtc/hip.hpp index e962d4cd3e..af2f4a9122 100644 --- a/codegen/test/rtc/include/rtc/hip.hpp +++ b/codegen/test/rtc/include/rtc/hip.hpp @@ -1,10 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP #include #include -#include #include +#include namespace rtc { diff --git a/codegen/test/rtc/include/rtc/kernel.hpp b/codegen/test/rtc/include/rtc/kernel.hpp index 9f38e90416..b1ee729f77 100644 --- a/codegen/test/rtc/include/rtc/kernel.hpp +++ b/codegen/test/rtc/include/rtc/kernel.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL diff --git a/codegen/test/rtc/include/rtc/manage_ptr.hpp b/codegen/test/rtc/include/rtc/manage_ptr.hpp index 92edf12628..52b94d4b70 100644 --- a/codegen/test/rtc/include/rtc/manage_ptr.hpp +++ b/codegen/test/rtc/include/rtc/manage_ptr.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER diff --git a/codegen/test/rtc/include/rtc/tmp_dir.hpp b/codegen/test/rtc/include/rtc/tmp_dir.hpp index a0a2cb9b77..2f3b26cc43 100644 --- a/codegen/test/rtc/include/rtc/tmp_dir.hpp +++ b/codegen/test/rtc/include/rtc/tmp_dir.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index 8cb71b9043..5a70f898e8 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include diff --git a/codegen/test/rtc/src/hip.cpp b/codegen/test/rtc/src/hip.cpp index 747f83e3ba..6f16e36720 100644 --- a/codegen/test/rtc/src/hip.cpp +++ b/codegen/test/rtc/src/hip.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include diff --git a/codegen/test/rtc/src/kernel.cpp b/codegen/test/rtc/src/kernel.cpp index 9fe38e84ad..982e95de17 100644 --- a/codegen/test/rtc/src/kernel.cpp +++ b/codegen/test/rtc/src/kernel.cpp @@ -1,6 +1,10 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include +#include #include // extern declare the function since hip/hip_ext.h header is broken diff --git a/codegen/test/rtc/src/tmp_dir.cpp b/codegen/test/rtc/src/tmp_dir.cpp index 4e89bc3539..b36b17cce1 100644 --- a/codegen/test/rtc/src/tmp_dir.cpp +++ b/codegen/test/rtc/src/tmp_dir.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 7f48a51ce8..e9df8c9f5f 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.14.1 +rocm-docs-core==1.15.0 sphinxcontrib-bibtex==2.6.3 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 0332e19bc7..a42fdf09bf 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -199,7 +199,7 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core==1.14.1 +rocm-docs-core==1.15.0 # via -r requirements.in rpds-py==0.22.3 # via diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 77f15a213c..97ac21eba5 100755 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -61,7 +61,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) -list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 07d51855d6..414683ffdf 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -31,9 +31,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>; -// // clang-format on -// clang-format off using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index be47665a26..aa9367cdcf 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -16,7 +16,7 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/18_batched_gemm_reduce/CMakeLists.txt b/example/18_batched_gemm_reduce/CMakeLists.txt index 94ed129dc0..018b57f82c 100644 --- a/example/18_batched_gemm_reduce/CMakeLists.txt +++ b/example/18_batched_gemm_reduce/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc index e3370b880b..ce42a20be7 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc @@ -32,6 +32,56 @@ using BiasLayout = typename LayoutSettingSelector::BiasLayout; template using ResidualLayout = typename LayoutSettingSelector::ResidualLayout; +#if defined(CK_USE_AMD_MFMA_GFX950) +template +using DeviceConvFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InputLayout, + WeightLayout, + ck::Tuple, ResidualLayout>, + OutputLayout, + InKernelDataType, + WeiKernelDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutKernelDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector + 4, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 16, 1, 16>, + 4>; +#else // defined(CK_USE_AMD_MFMA_GFX950) template using DeviceConvFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< @@ -80,6 +130,7 @@ using DeviceConvFwdInstance = 1, S<1, 16, 1, 16>, 4>; +#endif // defined(CK_USE_AMD_MFMA_GFX950) template using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" + +using ScaleDataType = ck::e8m0_bexp_t; + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +struct ExecutionConfig final +{ + int do_verification = 1; // (0=no, 1=CPU) + int init_method = 2; // (0=no init, 1=integer value, 2=decimal value) + bool time_kernel = false; // (0=no, 1=yes) + int verbosity = 0; // (0=no info, 1=verbose info) +}; + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; +}; + +bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.verbosity = std::stoi(argv[4]); + } + else if(argc == 11) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.verbosity = std::stoi(argv[4]); + + problem_size.M = std::stoi(argv[5]); + problem_size.N = std::stoi(argv[6]); + problem_size.K = std::stoi(argv[7]); + + problem_size.StrideA = std::stoi(argv[8]); + problem_size.StrideB = std::stoi(argv[9]); + problem_size.StrideC = std::stoi(argv[10]); + } + else + { + std::cerr << "arg1: verification (0=no, 1=CPU)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4: verbosity (0=no info, 1=verbose info)" << std::endl + << "arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC" << std::endl; + return false; + } + + return true; +} + +template +bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using ELayout = CLayout; + using DsLayout = ck::Tuple<>; + using DsDataType = ck::Tuple<>; + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = CElementWiseOp; + + static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + static constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; + static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +#if 1 + // XXX: These parameters should not exist in MX-native GEMM kernel + static constexpr ck::index_t Scale_Block_M = 128; + static constexpr ck::index_t Scale_Block_N = 128; +#endif + static constexpr ck::index_t Scale_Block_K = MXVectorSize; + + // XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize MX-specific MFMA + // instructions. + // + // XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize device-optimized + // scaled type convert functions. + // + // XXX: In DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3, KPerBlock is expected to be equal to + // ScaleBlockK (aka MXVectorSize). + // Additionally, the following is also expected: + // static_assert(ScaleBlockM % MPerBlock == 0); + // static_assert(ScaleBlockN % NPerBlock == 0); + // In MX-native GEMM kernel these requirements should be relaxed. + // + // XXX: It appears, by default we are using mfma_f32_16x16x4xf32 + // MfmaSelector::selected_mfma.k_per_blk = + // MfmaSelector::selected_mfma.k_per_blk = mfma_f32_16x16x4xf32 + // XXX: GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 assumes scale type is float + + // clang-format off + using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 + // ######| ALayout| BLayout| DsLayout| CLayout| ADataType| AScale| BDataType| BScale| DsDataType| CDataType| GemmAcc| CShuffleDataType|AElementwise|BElementwise| CElementwise| GemmSpec|Block| ScaleBlockM| ScaleBlockN| ScaleBlockK| M| N| K| AK1| BK1| M| N|MXdl|NXdl|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer| ABlock|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer| BBlock| CShuffle| CShuffle|CShuffleBlockTransfer|CDEShuffleBlockTransfer| BlkGemm| BlkGemm|ComputeTypeA|ComputeTypeB|LDSTypeA|LDSTypeB| + // ######| | | | | | DataType| | DataType| | | DataType| | Operation| Operation| Operation| | Size| | | | Per| Per| Per| | | Per| Per| Per| Per| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|LdsExtraM| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVector| SrcScalar| DstScalar|LdsExtraN| MXdl| NXdl| ClusterLengths| Scalar| PipeSched| PipelineVer| | | | | + // ######| | | | | | | | | | | | | | | | | | | | |Block|Block| Block| | | XDL| XDL|Wave|Wave| Lengths| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths| ArrangeOrder| | Dim| PerVector| PerVector_BK1| | PerWave| PerWave| MBlock_MPerBlock| PerVectors| | | | | | | + // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | AK0_M_AK1| | | | | | | BK0_N_BK1| | | | | |PerShuffle|PerShuffle| NBlock_NPerBlock| | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, XDataType, BDataType, XDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlkGemmPSched, BlkGemmPVer, float, float, float, float>; + // clang-format on + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + + auto f_host_tensor_descriptor = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1}); + } + else + { + return HostTensorDescriptor({row, col}, {1, stride}); + } + }; + + auto f_get_default_stride = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + if(K % Scale_Block_K != 0) + { + throw std::runtime_error("wrong! K must be multiple of Scale_Block_K (16 or 32)"); + }; + + auto Scale_Stride_AM = f_get_default_stride(M, K / Scale_Block_K, StrideA, ALayout{}); + auto Scale_Stride_BN = f_get_default_stride(K / Scale_Block_K, N, StrideB, BLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + Tensor a_m_k_scale( + f_host_tensor_descriptor(M, K / Scale_Block_K, Scale_Stride_AM, ALayout{})); // scales for A + Tensor b_k_n_scale( + f_host_tensor_descriptor(K / Scale_Block_K, N, Scale_Stride_BN, BLayout{})); // scales for B + + Tensor c_m_n_host_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification + Tensor c_m_n_device_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // device result downloaded to host + + if(config.verbosity >= 0) + { + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl; + std::cout << "c_m_n_device_result: " << c_m_n_device_result.mDesc << std::endl; + } + + switch(config.init_method) + { + case 0: + if(config.verbosity > 0) + { + std::cout << "NOTE: No input data initialization." << std::endl; + } + break; + case 1: + case 2: + ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k); + ck::utils::FillConstant{ck::type_convert(0.5f)}(a_m_k_scale); + ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n); + ck::utils::FillConstant{ck::type_convert(2.0f)}(b_k_n_scale); + if(config.verbosity > 0) + { + std::cout << "Init A = {1}" << std::endl; + std::cout << "Init A scale = {0.5}" << std::endl; + std::cout << "Init B = {1}" << std::endl; + std::cout << "Init B scale = {2.0}" << std::endl; + std::cout << "Expect C = {K}" << std::endl; + } + break; + + default: + if(config.verbosity > 0) + { + std::cout << "NOTE: No input data initialization." << std::endl; + } + } + + if(config.verbosity > 0) + std::cout << "Device memory allocation..." << std::endl; + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + if(config.verbosity > 0) + std::cout << "Upload data to device..." << std::endl; + a_device_buf.ToDevice(a_m_k.mData.data()); + a_scale_device_buf.ToDevice(a_m_k_scale.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + b_scale_device_buf.ToDevice(b_k_n_scale.mData.data()); + if(config.verbosity > 0) + std::cout << "Done." << std::endl; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + c_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideC, + a_scale_device_buf.GetDeviceBuffer(), + b_scale_device_buf.GetDeviceBuffer(), + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong!\n" + "Provided combination of compilation and runtime parameters is " + "not consistent with the supported device_gemm arguments."); + } + + if(config.verbosity > 0) + std::cout << "Computing GEMM on device..." << std::endl; + float ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50}); + + bool res_verified = true; + if(config.do_verification > 0) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + if(config.verbosity > 0) + { + std::cout << "Done." << std::endl; + std::cout << "Computing GEMM on host..." << std::endl; + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + a_m_k_scale, + b_k_n, + b_k_n_scale, + c_m_n_host_result, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + if(config.verbosity > 0) + { + std::cout << "Done." << std::endl; + std::cout << "Comparing results..." << std::endl; + } + + if(config.init_method == 1) + { + res_verified = + res_verified && std::abs(static_cast(K) - c_m_n_device_result(0, 0)) <= 0.0f; + std::cout << "Expected vs Computed: " << 1.0f * K << " vs " << c_m_n_device_result(0, 0) + << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl; + } + + res_verified = res_verified && ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!"); + + if(config.verbosity > 0 && res_verified) + std::cout << "Done." << std::endl; + } + else + { + if(config.verbosity > 0) + std::cout << "Done." << std::endl; + } + + if(config.time_kernel) + { + std::size_t flop = std::size_t(2) * M * N * K + M * K + K * N; // GEMM + A scale + B scale + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + + sizeof(XDataType) * (M * K + K * N) / Scale_Block_K; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + } + + return res_verified; +} + +template +bool run_mx_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && + run_mx_gemm(problem_size, config); +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8.cpp new file mode 100644 index 0000000000..d2e21698ec --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp8.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +#if 1 +// XXX: MX-native GEMM kernel will work with e8m0_bexp_t scale type +using XDataType = float; +#else +using XDataType = ck::e8m0_bexp_t; +#endif +using AccDataType = float; +using CShuffleDataType = float; +using CDataType = float; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t mx_vector_size = 128; // scaling block size + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index f26d738625..bcb62df625 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -23,34 +23,34 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") set(result 1) if(DEFINED DTYPES) - foreach(source IN LISTS FILE_NAME) - set(test 0) - if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) - set(test 1) - endif() - if(test EQUAL 1) - message("removing example source file ${source} ") - list(REMOVE_ITEM FILE_NAME "${source}") - endif() - endforeach() + foreach(source IN LISTS FILE_NAME) + set(test 0) + if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() + if(test EQUAL 1) + message("removing example source file ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() endif() set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) @@ -83,6 +83,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() + #Do not build any microscaling examples if gfx950 target is not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT EX_TARGETS MATCHES "gfx950" AND source MATCHES "_mx") + message("removing microscaling example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() #Do not build any FP8 examples if CK_ENABLE_FP8 not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8") @@ -102,7 +109,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) if(FILE_NAME MATCHES "_xdl") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950) + elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950 + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) @@ -195,7 +204,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) if(FILE_NAME MATCHES "_xdl") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index c3a66ba3ea..2e04780eb0 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -12,7 +12,13 @@ #include "ck_tile/host.hpp" #include "gemm_basic.hpp" -template +template float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. @@ -20,16 +26,12 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& constexpr bool kPadN = false; constexpr bool kPadK = false; - constexpr bool kTilePermute = false; - // The rank and permutation will also be generate out by the CodeGen part. - constexpr ck_tile::index_t kOutputRank = 2; - constexpr int kBlockPerCu = 1; // This part comes from the Codegen constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t N_Tile = 128; - constexpr ck_tile::index_t K_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; @@ -37,40 +39,33 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; - - // Whether doing the CShuffle (transpose before the global memory), depending on the output - // layout. - constexpr bool CShuffleEpilogue = - std::is_same_v; + constexpr ck_tile::index_t K_Warp_Tile = 16; using CodegenGemmShape = ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>; - using TilePartitioner = ck_tile::GemmTile2DPartitioner; - - using GemmEpilogue = std::conditional_t< - CShuffleEpilogue, - ck_tile::CShuffleEpilogue>, - ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>>; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; using CodegenGemmTraits = ck_tile::TileGemmTraits; using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. using Kernel = ck_tile::GemmKernel; @@ -110,12 +105,32 @@ int run_gemm_example(int argc, char* argv[]) using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - std::string a_layout = arg_parser.get_str("a_layout"); - std::string b_layout = arg_parser.get_str("b_layout"); + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } } else { diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index 4500e3b4fd..5fa94f5f72 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -18,7 +18,7 @@ #define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE #endif -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave @@ -43,6 +43,33 @@ struct GemmBasicTypeConfig // ToDo: Add more bias config to support different categories of GEMM. }; +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + template struct DataTypeTraits; @@ -64,13 +91,23 @@ struct DataTypeTraits static constexpr const char* name = "fp16"; }; -using Types = GemmBasicTypeConfig; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; -// Specific type aliases for easy access -using ADataType = Types::ADataType; -using BDataType = Types::BDataType; -using AccDataType = Types::AccDataType; -using CDataType = Types::CDataType; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; auto create_args(int argc, char* argv[]) { @@ -79,7 +116,7 @@ auto create_args(int argc, char* argv[]) .insert("n", "4096", "n dimension") .insert("k", "2048", "k dimension") .insert("a_layout", "R", "A tensor data layout - Row by default") - .insert("b_layout", "R", "B tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Column by default") .insert("c_layout", "R", "C tensor data layout - Row by default") .insert("stride_a", "0", "Tensor A stride") .insert("stride_b", "0", "Tensor B stride") diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index d32ec57be5..028f8a44c3 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -9,6 +9,7 @@ static constexpr inline auto is_row_major(Layout layout_) ck_tile::tensor_layout::gemm::RowMajor>>{}; } +template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) @@ -29,7 +30,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template +template float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& c_m_n_dev_buf, @@ -55,7 +57,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; - float ave_time = gemm_calc( + float ave_time = gemm_calc( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = std::size_t(2) * M * N * K; @@ -66,13 +69,19 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C + << " A_Layout =" << ALayout::name + << " B_Layout =" << BLayout::name + << " C_Layout =" << CLayout::name + << " A Type = " << DataTypeTraits::name + << " B Type = " << DataTypeTraits::name + << " C Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; return ave_time; } -template +template int run_gemm_example_with_layouts(int argc, char* argv[], const ALayout a_layout = ALayout{}, @@ -83,6 +92,11 @@ int run_gemm_example_with_layouts(int argc, if(!result) return -1; + using ADataType = typename GemmBasicTypeConfig::ADataType; + using BDataType = typename GemmBasicTypeConfig::BDataType; + using CDataType = typename GemmBasicTypeConfig::CDataType; + using AccDataType = typename GemmBasicTypeConfig::AccDataType; + ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t K = arg_parser.get_int("k"); @@ -119,7 +133,8 @@ int run_gemm_example_with_layouts(int argc, c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - invoke_gemm(a_m_k_dev_buf, + invoke_gemm(a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_dev_buf, M, @@ -145,7 +160,8 @@ int run_gemm_example_with_layouts(int argc, a_m_k, b_k_n, c_m_n_host_ref); const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + const auto rtol_atol = calculate_rtol_atol + (K, kbatch, max_accumulated_value); pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref, "Error: Incorrect results!", @@ -202,7 +218,8 @@ int run_gemm_example_with_layouts(int argc, c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); const float max_accumulated_value = *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + const auto rtol_atol = calculate_rtol_atol + (K, kbatch, max_accumulated_value); pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref, "Error: Incorrect results!", diff --git a/example/ck_tile/03_gemm/script/benchmark_basic.sh b/example/ck_tile/03_gemm/script/benchmark_basic.sh index f5473e46f4..a1646da5bd 100755 --- a/example/ck_tile/03_gemm/script/benchmark_basic.sh +++ b/example/ck_tile/03_gemm/script/benchmark_basic.sh @@ -1,12 +1,13 @@ #!/bin/sh EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" -VALID=0 +VALID=1 -for b_matrix_layout in "R" "C"; do + +for b_matrix_layout in "C"; do for m in "64" "512" "1024" "2048"; do for n in "512" "1024" "2048"; do for k in "64" "512" "1024" "2048"; do - $EXE -prec=fp16 -b=1 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + $EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID done done done diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh b/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh b/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh b/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh new file mode 100644 index 0000000000..21462616be --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh @@ -0,0 +1,14 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" +VALID=1 + + +for b_matrix_layout in "C"; do + for m in "64" "512" "1024" "2048"; do + for n in "512" "1024" "2048"; do + for k in "64" "512" "1024" "2048"; do + $EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh index a3029cbeb5..c4cf4ddcbf 100755 --- a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh @@ -1,12 +1,12 @@ #!/bin/sh EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" -VALID=0 +VALID=1 -for b_matrix_layout in "R" "C"; do - for m in "64" "512" "1024" "2048"; do +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do for n in "512" "1024" "2048"; do - for k in "64" "512" "1024" "2048"; do - $EXE -prec=fp16 -b=1 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + for k in "512" "1024" "2048"; do + $EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID done done done diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh new file mode 100644 index 0000000000..903b4a3c0f --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=bf16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh new file mode 100644 index 0000000000..8c92c2e991 --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=bf8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh new file mode 100644 index 0000000000..e238006c7d --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/smoke_test_basic.sh b/example/ck_tile/03_gemm/script/smoke_test_basic.sh index 8eb4e101a0..7ca6759f42 100755 --- a/example/ck_tile/03_gemm/script/smoke_test_basic.sh +++ b/example/ck_tile/03_gemm/script/smoke_test_basic.sh @@ -7,22 +7,20 @@ export CK_REPEAT=1 COMMON_ARGS='-v=2 -warmup=0 -repeat=1' -run_fp16_tests() { - for batch in 1 2; do - for m in 128 1024; do - for n in 128 2048; do - for k in 32 64; do +run_tests() { + for m in 128 1024; do + for n in 128 2048; do + for k in 64 128; do - $EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS - if [ $? -eq 0 ]; then - echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully." - else - echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly." - # Optionally, exit or break if you need to halt further execution - # exit 1 - fi + $EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS + if [ $? -eq 0 ]; then + echo "Success: Test with m=$m, n=$n, k=$k executed successfully." + else + echo "Error: Test with m=$m, n=$n, k=$k failed to execute properly." + # Optionally, exit or break if you need to halt further execution + # exit 1 + fi - done done done done @@ -30,6 +28,9 @@ run_fp16_tests() { set -x -run_fp16_tests +run_tests "fp16" +run_tests "bf16" +run_tests "fp8" +run_tests "bf8" set +x diff --git a/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh b/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh index a9c7f48da0..951f8aa63a 100755 --- a/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh +++ b/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh @@ -7,22 +7,20 @@ export CK_REPEAT=1 COMMON_ARGS='-v=2 -warmup=0 -repeat=1' -run_fp16_tests() { - for batch in 1 2; do - for m in 128 1024; do - for n in 128 2048; do - for k in 32 64; do +run_tests() { + for m in 512 1024; do + for n in 512 2048; do + for k in 512 1024; do - $EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS - if [ $? -eq 0 ]; then - echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully." - else - echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly." - # Optionally, exit or break if you need to halt further execution - # exit 1 - fi + $EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS + if [ $? -eq 0 ]; then + echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully." + else + echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly." + # Optionally, exit or break if you need to halt further execution + # exit 1 + fi - done done done done @@ -30,6 +28,9 @@ run_fp16_tests() { set -x -run_fp16_tests +run_tests "fp16" +run_tests "bf16" +run_tests "fp8" +run_tests "bf8" set +x diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 5d2bd2df31..08a9cdb24b 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -12,7 +12,13 @@ #include "ck_tile/host.hpp" #include "gemm_basic.hpp" -template +template float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) @@ -33,7 +39,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& // Compute friendly for Intrawave scheduler constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; @@ -50,7 +56,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& constexpr bool TransposeC = false; - constexpr int kBlockPerCu = 1; + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; // =============================================== @@ -58,10 +66,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>; - using TilePartitioner = ck_tile::GemmTile2DPartitioner; - - using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile:: @@ -95,6 +101,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -230,24 +249,101 @@ int run_gemm_example(int argc, char* argv[]) using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - std::string a_layout = arg_parser.get_str("a_layout"); - std::string b_layout = arg_parser.get_str("b_layout"); + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); if(a_layout == "R" && b_layout == "R") { - return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } } else if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } } else if(a_layout == "C" && b_layout == "R") { - return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } } else { diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index 720802236c..949621e116 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -19,12 +19,9 @@ template float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - constexpr bool kTilePermute = false; - // The rank and permutation will also be generate out by the CodeGen part. - constexpr ck_tile::index_t kOutputRank = 2; + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; constexpr int kBlockPerCu = 1; @@ -41,38 +38,31 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 8; - // Whether doing the CShuffle (transpose before the global memory), depending on the output - // layout. - constexpr bool CShuffleEpilogue = - std::is_same_v; - using CodegenGemmShape = ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>; - using TilePartitioner = ck_tile::GemmTile2DPartitioner; - - using GemmEpilogue = std::conditional_t< - CShuffleEpilogue, - ck_tile::CShuffleEpilogue>, - ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>>; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; using CodegenGemmTraits = ck_tile::TileGemmTraits; using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. using Kernel = ck_tile::BatchedGemmKernel; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index bb4bdbf514..c32fac6c0d 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -20,12 +20,9 @@ namespace { struct GroupedGemmKernelParam { - static const bool kPadM = false; - static const bool kPadN = false; - static const bool kPadK = false; - static const bool kTilePermute = false; - - static const ck_tile::index_t kOutputRank = 2; + static const bool kPadM = false; + static const bool kPadN = false; + static const bool kPadK = false; static const int kBlockPerCu = 1; static const ck_tile::index_t M_Tile = 128; @@ -54,24 +51,6 @@ using CodegenGemmShape = using TilePartitioner = ck_tile::GemmTile1DPartitioner; -template -using GemmEpilogue = std::conditional_t< - std::is_same_v, - ck_tile::CShuffleEpilogue>, - ck_tile::Default2DEpilogue>>; - template using CodegenGemmTraits = ck_tile::TileGemmTraits using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1>; +template +using GemmEpilogue = ck_tile::CShuffleEpilogue::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GroupedGemmKernelParam::M_Warp, + GroupedGemmKernelParam::N_Warp, + GroupedGemmKernelParam::M_Warp_Tile, + GroupedGemmKernelParam::N_Warp_Tile, + GroupedGemmKernelParam::K_Warp_Tile, + CodegenPipelineProblem::TransposeC>>; + template using Kernel = ck_tile::GroupedGemmKernel, - GemmEpilogue>; + GemmEpilogue>; }; // namespace std::size_t get_workspace_size(const std::vector& gemm_descs) diff --git a/include/ck/README.md b/include/ck/README.md index bff689f6b0..92d5a51087 100644 --- a/include/ck/README.md +++ b/include/ck/README.md @@ -1,19 +1,23 @@ [Back to the main page](../../README.md) # Composable Kernel supported operations ## Supported device operations -* [Average pooling]() -* [Batched contraction]() -* [Batched gemm]() -* [Batchnorm]() -* [CGEMM]() -* [Contraction]() -* [Convolution]() -* [Image to Column and Column to Image]() -* [Elementwise]() -* [GEMM]() -* [Max pooling]() -* [Reduce]() -* [Normalization]() -* [Permute]() -* [Put]() -* [Softmax]() + + + + + + + + +* [GEMM](../../client_example/01_gemm/README.md) +* [Grouped Convolution Forward](../../client_example/07_grouped_convnd_fwd/README.md) +* [Grouped Convolution Backward Data](../../client_example/10_grouped_convnd_bwd_data/README.md) +* [Grouped Convolution Backward Weight](../../client_example/11_grouped_conv_bwd_weight/README.md) + + + + + + + + diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index fc9d074716..1ec0c6bc23 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -5,7 +5,7 @@ #include "ck/config.h" #include "ck/utility/env.hpp" - +#ifndef CK_CODE_GEN_RTC #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" @@ -14,7 +14,7 @@ // environment variable to enable logging: // export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) - +#endif // to do: add various levels of logging with CK_LOG_LEVEL #ifndef CK_TIME_KERNEL @@ -55,10 +55,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // define general macros for various architectures #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) + defined(__gfx942__) || defined(__gfx950__) #define __gfx9__ #endif -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__) #define __gfx94__ #endif #if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) @@ -163,6 +163,16 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // set rounding to nearest even as default for f8 conversions #define CK_USE_SR_F8_CONVERSION 0 +// set rounding to nearest even as default for f6 conversions +#define CK_USE_SR_F6_CONVERSION 0 + +// set rounding to nearest even as default for f4 conversions +#define CK_USE_SR_F4_CONVERSION 0 + +// shuffle pk_i4 values during conversion to optimize number of binary +// operations +#define CK_USE_PK4_LAYOUT_SHUFFLE 1 + // block synchronization only s_wait lgkmcnt(0), not vmcnt(0) #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 3a590c676f..994e60025d 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -131,6 +131,10 @@ #cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@ #endif +#ifndef CK_USE_NATIVE_MX_SUPPORT +#cmakedefine CK_USE_NATIVE_MX_SUPPORT @CK_USE_NATIVE_MX_SUPPORT@ +#endif + // clang-format on #endif // CK_CONFIG_H_IN diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index f5c4b43ad2..05dc491af7 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -55,20 +55,21 @@ inline bool is_xdl_supported() { return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || - ck::get_device_name() == "gfx942"; + ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"; } inline bool is_lds_direct_load_supported() { // Check if direct loads from global memory to LDS are supported. return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" || - ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"; + ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942" || + ck::get_device_name() == "gfx950"; } inline bool is_bf16_atomic_supported() { return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || - ck::get_device_name() == "gfx942"; + ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"; } inline bool is_gfx101_supported() diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index 08bfefb87f..d33ecaeef8 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -26,6 +26,7 @@ namespace utils { template double get_relative_threshold(const int number_of_accumulations = 1) { + using F4 = ck::f4_t; using F8 = ck::f8_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; @@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) using I8 = int8_t; using I32 = int32_t; - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); double compute_error = 0; if constexpr(is_same_v || is_same_v || @@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) compute_error = std::pow(2, -NumericUtils::mant) * 0.5; } - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled OutDataType for setting up the relative threshold!"); double output_error = 0; if constexpr(is_same_v || is_same_v || @@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) } double midway_error = std::max(compute_error, output_error); - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled AccDataType for setting up the relative threshold!"); double acc_error = 0; if constexpr(is_same_v || is_same_v || @@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1) template double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1) { + using F4 = ck::f4_t; using F8 = ck::f8_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; @@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of using I8 = int8_t; using I32 = int32_t; - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); auto expo = std::log2(std::abs(max_possible_num)); double compute_error = 0; @@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of compute_error = std::pow(2, expo - NumericUtils::mant) * 0.5; } - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled OutDataType for setting up the absolute threshold!"); double output_error = 0; if constexpr(is_same_v || is_same_v || @@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of } double midway_error = std::max(compute_error, output_error); - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled AccDataType for setting up the absolute threshold!"); double acc_error = 0; if constexpr(is_same_v || is_same_v || @@ -450,5 +452,54 @@ check_err(const Range& out, return res; } +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, f4_t>), + bool> +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 0.5, + double atol = 0.5) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + + if(!res) + { + std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err + << " number of errors: " << err_count << std::endl; + } + return res; +} + } // namespace utils } // namespace ck diff --git a/include/ck/library/utility/host_tensor_generator.hpp b/include/ck/library/utility/host_tensor_generator.hpp index 6a90523c33..274051da83 100644 --- a/include/ck/library/utility/host_tensor_generator.hpp +++ b/include/ck/library/utility/host_tensor_generator.hpp @@ -69,6 +69,18 @@ struct GeneratorTensor_1 }; #endif +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::f4_t operator()(Is...) + { + return ck::type_convert(value); + } +}; + template <> struct GeneratorTensor_1 { @@ -183,6 +195,20 @@ struct GeneratorTensor_2 }; #endif +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::f4_t operator()(Is...) + { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + return ck::type_convert(tmp); + } +}; + template struct GeneratorTensor_3 { @@ -253,6 +279,23 @@ struct GeneratorTensor_3 }; #endif +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::f4_t operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp = min_value + tmp * (max_value - min_value); + + return ck::type_convert(fp32_tmp); + } +}; + template struct GeneratorTensor_4 { diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp index 1c4de5ed31..0a0bcbac38 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -131,7 +131,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2 } template - using is_tuple = decltype(std::declval().IsTuple()); + using is_tuple = decltype(ck::declval().IsTuple()); template __device__ void RunWrite(const DstDescs& dst_descs, diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp index 0eef827a5b..cf20025d46 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -1,9 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#ifndef CK_CODE_GEN_RTC #include +#endif namespace ck { namespace tensor_operation { @@ -18,6 +20,7 @@ enum struct ConvolutionForwardSpecialization Filter3x3, }; +#ifndef CK_CODE_GEN_RTC inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s) { switch(s) @@ -30,6 +33,7 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp default: return "Unrecognized specialization!"; } } +#endif } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 736e241fdf..774982d905 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -1,19 +1,21 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#ifndef CK_CODE_GEN_RTC #include #include #include #include - #include "ck/stream_config.hpp" +#endif namespace ck { namespace tensor_operation { namespace device { +#ifndef CK_CODE_GEN_RTC #define GET_OBJECT_NAME_IMLP \ std::optional GetObjectName() const override \ { \ @@ -41,7 +43,9 @@ namespace device { } #define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL +#endif +#ifndef CK_CODE_GEN_RTC struct BaseArgument { BaseArgument() = default; @@ -66,13 +70,14 @@ struct BaseInvoker virtual ~BaseInvoker() {} }; +#endif struct BaseOperator { BaseOperator() = default; BaseOperator(const BaseOperator&) = default; BaseOperator& operator=(const BaseOperator&) = default; - +#ifndef CK_CODE_GEN_RTC virtual bool IsSupportedArgument(const BaseArgument*) { return false; } virtual std::string GetTypeString() const { return ""; } @@ -100,7 +105,7 @@ struct BaseOperator assert(p_arg); p_arg->p_workspace_ = p_workspace; } - +#endif virtual ~BaseOperator() {} }; diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp index 184efbbd68..8c9b768a8b 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp @@ -1,9 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#ifndef CK_CODE_GEN_RTC #include +#endif #include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" @@ -13,8 +15,13 @@ namespace ck { namespace tensor_operation { namespace device { +#ifdef CK_CODE_GEN_RTC +template +using is_tuple = decltype(ck::declval().IsTuple()); +#else template using is_tuple = decltype(std::declval().IsTuple()); +#endif /** * \brief Grouped Convolution Forward @@ -72,12 +79,18 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator static constexpr index_t NumDTensor = DsDataType::Size(); static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor"); - +#ifdef CK_CODE_GEN_RTC + using APointers = ck::conditional_t&, const void*>; + using BPointers = ck::conditional_t&, const void*>; +#else // If DataType is tuple, user has to pass std::array with pointers. using APointers = - std::conditional_t&, const void*>; + ck::conditional_t&, const void*>; using BPointers = - std::conditional_t&, const void*>; + ck::conditional_t&, const void*>; +#endif + +#ifndef CK_CODE_GEN_RTC /** * \brief Make argument pointer for grouped conv fwd. @@ -150,6 +163,7 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator const CDEElementwiseOperation& cde_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; +#endif }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp index 0bb45b18c3..997dcb75a6 100644 --- a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -29,6 +29,7 @@ enum struct GemmSpecialization MNKOPadding, }; +#ifndef CK_CODE_GEN_RTC inline std::string getGemmSpecializationString(const GemmSpecialization& s) { switch(s) @@ -52,6 +53,7 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s) default: return "Unrecognized specialization!"; } } +#endif } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 180e32c8b6..00518b369f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -3,11 +3,17 @@ #pragma once +#ifndef CK_CODE_GEN_RTC #include #include #include #include #include +#include + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#endif #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -15,15 +21,12 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/host_utility/io.hpp" namespace ck { namespace tensor_operation { @@ -91,8 +94,7 @@ __device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle( const Block2ETileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -259,8 +261,13 @@ __global__ void } // namespace +#ifdef CK_CODE_GEN_RTC +template +using is_tuple = decltype(ck::declval().IsTuple()); +#else template using is_tuple = decltype(std::declval().IsTuple()); +#endif // // @brief Device Convolution operation. @@ -429,8 +436,8 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // If we are using multiAB and one of the template datatype parameters is not a tuple, convert // it to it - using GemmADataType = std::conditional_t, ADataType>; - using GemmBDataType = std::conditional_t, BDataType>; + using GemmADataType = ck::conditional_t, ADataType>; + using GemmBDataType = ck::conditional_t, BDataType>; #define GridwiseGemmTemplateParameters \ GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ @@ -449,15 +456,13 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CDEBlockTransferScalarPerVector_NPerBlock, LoopSched // Use appropriate gridwise gemm using GridwiseGemm = - std::conditional_t, - GridwiseGemmMultipleD_xdl_cshuffle>; + ck::conditional_t, + GridwiseGemmMultipleD_xdl_cshuffle>; // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers. - using APointers = - std::conditional_t&, const void*>; - using BPointers = - std::conditional_t&, const void*>; + using APointers = ck::conditional_t&, const void*>; + using BPointers = ck::conditional_t&, const void*>; // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not // in initializer list what is required for single const pointer). using AGridPointer = remove_cvref_t< @@ -812,7 +817,6 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle static_for<0, NumDTensor, 1>{}([&](auto i) { using DLayout = remove_cvref_t>; - // FIXME: layout if constexpr(is_same_v || is_same_v || is_same_v || is_same_v || @@ -965,18 +969,18 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle const BElementwiseOperation& b_element_op, const CDEElementwiseOperation& cde_element_op) { - std::array a_g_n_c_wis_lengths_i32; - std::array a_g_n_c_wis_strides_i32; - std::array b_g_k_c_xs_lengths_i32; - std::array b_g_k_c_xs_strides_i32; - std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; - std::array, NumDTensor> ds_g_n_k_wos_strides_i32; - std::array e_g_n_k_wos_lengths_i32; - std::array e_g_n_k_wos_strides_i32; - std::array conv_filter_strides_i32; - std::array conv_filter_dilations_i32; - std::array input_left_pads_i32; - std::array input_right_pads_i32; + ck::Array a_g_n_c_wis_lengths_i32; + ck::Array a_g_n_c_wis_strides_i32; + ck::Array b_g_k_c_xs_lengths_i32; + ck::Array b_g_k_c_xs_strides_i32; + ck::Array, NumDTensor> ds_g_n_k_wos_lengths_i32; + ck::Array, NumDTensor> ds_g_n_k_wos_strides_i32; + ck::Array e_g_n_k_wos_lengths_i32; + ck::Array e_g_n_k_wos_strides_i32; + ck::Array conv_filter_strides_i32; + ck::Array conv_filter_dilations_i32; + ck::Array input_left_pads_i32; + ck::Array input_right_pads_i32; array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index 64aa398d53..d53fbca4ea 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -56,8 +56,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index d06eab1264..25a9d7f96d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -74,8 +74,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp index e950169ccf..985752796b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -60,8 +60,7 @@ __global__ void const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -108,7 +107,7 @@ __global__ void ignore = block_2_ctile_map; ignore = batch_count; ignore = compute_base_ptr_of_batch; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } // Computes C = A * B0 * B1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp index d6b92bc97a..630f143260 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp @@ -83,8 +83,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index 6ab1669e30..f6c228fb7b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -68,8 +68,7 @@ __global__ void const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index 34b1d503af..30ae72a63e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -59,8 +59,7 @@ __global__ void const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index e178b8f525..2662e5c360 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -67,8 +67,7 @@ __global__ void const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -127,7 +126,7 @@ __global__ void ignore = batch_count; ignore = compute_base_ptr_of_batch; ignore = c0_matrix_mask; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } // Computes C = A * B0 * B1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index 9af1a44781..bfbcebd7c8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -62,8 +62,7 @@ __global__ void const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -112,7 +111,7 @@ __global__ void ignore = batch_count; ignore = compute_base_ptr_of_batch; ignore = c0_matrix_mask; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } // Computes C = A * B0 * B1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp index 6be2ffbdd7..494524b6f0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp @@ -52,8 +52,7 @@ __global__ void #endif kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp index e4203e0313..9482812f75 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp @@ -3,6 +3,7 @@ #pragma once +#include "ck/library/utility/numeric.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp index dae16612cc..df5922a04f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -55,8 +55,7 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_as_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index 6e69213513..8aa20f7ad4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -55,8 +55,7 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -97,7 +96,7 @@ __global__ void ignore = b_element_op; ignore = c_element_op; ignore = block_2_ctile_map; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } // specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp index 811f1ae939..b9467ac194 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp @@ -50,9 +50,8 @@ __global__ void const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ - defined(__gfx12__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx9__) || \ + defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp index eaafd7d5c5..47fb630ea9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp @@ -63,8 +63,7 @@ __global__ void const Block2ETileMap block_2_etile_map, index_t NRaw) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()]; GridwiseGemmWelford::template Run( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index bb2db930c8..c048e7249c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -60,8 +60,7 @@ __global__ void const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 77ed9625c5..e6466a487b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -52,8 +52,7 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp index 600f12139d..1c14496650 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -138,6 +138,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 0) { arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); } if(!GridwiseGemm::CheckValidity(arg)) @@ -733,7 +734,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp index cc022b89c5..1cf58fec25 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -37,8 +37,7 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index c8c58d5d85..99bd3be15d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -87,8 +87,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const ComputePtrOffsetOfN compute_ptr_offset_of_n) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index a7df1c9d57..57c4b1a5cf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -60,8 +60,7 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -103,7 +102,7 @@ __global__ void compute_ptr_offset_of_batch.GetAPtrOffset(0); compute_ptr_offset_of_batch.GetBPtrOffset(0); compute_ptr_offset_of_batch.GetCPtrOffset(0); -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template #include +#include "ck/library/utility/numeric.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -98,8 +99,7 @@ __global__ void const ComputePtrOffsetOfG compute_ptr_offset_of_groups, const ComputePtrOffsetOfN compute_ptr_offset_of_n) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); @@ -212,9 +212,13 @@ __global__ void } } // namespace - +#ifdef CK_CODE_GEN_RTC +template +using is_tuple = decltype(ck::declval().IsTuple()); +#else template using is_tuple = decltype(std::declval().IsTuple()); +#endif // // @brief Device Convolution operation. diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 589a0daa99..9363d7ecb9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -9,6 +9,7 @@ #include #include +#include "ck/library/utility/numeric.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -117,7 +118,7 @@ __global__ void c_grid_desc_mblock_mperblock_nblock_nperblock); #else ignore = karg; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template ::value, bool>::type = false> @@ -438,6 +439,7 @@ std::ostream& operator<<(std::ostream& os, const Layout&) os << Layout::name; return os; } +#endif } // namespace tensor_layout } // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index c87c90a91d..530876650e 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -340,8 +340,8 @@ struct Bilinear }; template <> - __host__ __device__ constexpr void operator()( - std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x0, const int8_t& x1) const { y = type_convert(alpha_ * type_convert(x0) + beta_ * type_convert(x1)); diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index b914c0b96f..370d03258d 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -533,7 +533,7 @@ struct NormalizeInInfer const T3& gamma, const T4& beta) const { - static_assert(std::is_same::value || std::is_same::value, + static_assert(is_same::value || is_same::value, "Data type is not supported by this operation!"); using ck::type_convert; diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 5e522fb2ea..f1055d1eff 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -16,7 +16,8 @@ namespace ck { // [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production] // (https://arxiv.org/abs/2211.10017) and implementation: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__host__ __device__ inline half4_t pki4_to_half4(int q) +// Convert lower part of packed int4 -> int4 to half +__device__ inline half4_t i4_to_half4(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; @@ -44,7 +45,7 @@ __host__ __device__ inline half4_t pki4_to_half4(int q) return res.template AsType()[Number<0>{}]; } -__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& scale) +__device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale) { const int LO = 0x000f000f; const int HI = 0x00f000f0; @@ -78,34 +79,7 @@ __host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& return res.template AsType()[Number<0>{}]; } -__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q) -{ -#if 1 - uint8_t x_u8 = ck::bit_cast(q); - uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4); - - const int EX = 0x64006400; - const int SUB = 0xE408E408; //-8 - - int lo = i4s | EX; - - return amd_assembly_pk_add_f16(bit_cast(lo), bit_cast(SUB)); -#else - uint8_t x_u8 = ck::bit_cast(q); - - vector_type res; - - half_t x_h = (x_u8 & 0x0f) - 8; - half_t x_l = ((x_u8 & 0xf0) >> 4) - 8; - - res.template AsType()(Number<0>{}) = x_l; - res.template AsType()(Number<1>{}) = x_h; - - return res.template AsType()[Number<0>{}]; -#endif -} - -__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q) +__device__ inline bhalf4_t i4_to_bhalf4(int q) { uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12); @@ -134,21 +108,6 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q) return res.template AsType()[Number<0>{}]; } -__host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q) -{ - uint8_t x_u8 = ck::bit_cast(q); - - float x_h = ((x_u8 & 0x0f) >> 0) - 8.f; - float x_l = ((x_u8 & 0xf0) >> 4) - 8.f; - - vector_type res; - - res.template AsType()(Number<0>{}) = type_convert(x_l); - res.template AsType()(Number<1>{}) = type_convert(x_h); - - return res.template AsType()[Number<0>{}]; -} - namespace tensor_operation { namespace element_wise { @@ -159,11 +118,11 @@ struct PassThroughPack8 __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const { -#if 1 +#if CK_USE_PK4_LAYOUT_SHUFFLE vector_type result; - result.template AsType()(Number<0>{}) = pki4_to_half4(bit_cast(x)); - result.template AsType()(Number<1>{}) = pki4_to_half4(bit_cast(x) >> 8); + result.template AsType()(Number<0>{}) = i4_to_half4(bit_cast(x)); + result.template AsType()(Number<1>{}) = i4_to_half4(bit_cast(x) >> 8); y = result.template AsType()[Number<0>{}]; #else @@ -171,13 +130,13 @@ struct PassThroughPack8 vector_type src{x}; dst.template AsType()(Number<0>{}) = - pki4_to_half2(src.template AsType()[Number<0>{}]); + type_convert(src.template AsType()[Number<0>{}]); dst.template AsType()(Number<1>{}) = - pki4_to_half2(src.template AsType()[Number<1>{}]); + type_convert(src.template AsType()[Number<1>{}]); dst.template AsType()(Number<2>{}) = - pki4_to_half2(src.template AsType()[Number<2>{}]); + type_convert(src.template AsType()[Number<2>{}]); dst.template AsType()(Number<3>{}) = - pki4_to_half2(src.template AsType()[Number<3>{}]); + type_convert(src.template AsType()[Number<3>{}]); y = dst.template AsType()[Number<0>{}]; #endif @@ -185,11 +144,11 @@ struct PassThroughPack8 __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const { -#if 1 +#if CK_USE_PK4_LAYOUT_SHUFFLE vector_type result; - result.template AsType()(Number<0>{}) = pki4_to_bhalf4(bit_cast(x)); - result.template AsType()(Number<1>{}) = pki4_to_bhalf4(bit_cast(x) >> 16); + result.template AsType()(Number<0>{}) = i4_to_bhalf4(bit_cast(x)); + result.template AsType()(Number<1>{}) = i4_to_bhalf4(bit_cast(x) >> 16); y = result.template AsType()[Number<0>{}]; #else @@ -197,13 +156,13 @@ struct PassThroughPack8 vector_type src{x}; dst.template AsType()(Number<0>{}) = - pki4_to_bhalf2(src.template AsType()[Number<0>{}]); + type_convert(src.template AsType()[Number<0>{}]); dst.template AsType()(Number<1>{}) = - pki4_to_bhalf2(src.template AsType()[Number<1>{}]); + type_convert(src.template AsType()[Number<1>{}]); dst.template AsType()(Number<2>{}) = - pki4_to_bhalf2(src.template AsType()[Number<2>{}]); + type_convert(src.template AsType()[Number<2>{}]); dst.template AsType()(Number<3>{}) = - pki4_to_bhalf2(src.template AsType()[Number<3>{}]); + type_convert(src.template AsType()[Number<3>{}]); y = dst.template AsType()[Number<0>{}]; #endif @@ -219,12 +178,12 @@ struct DequantPack8 __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const { -#if 1 +#if CK_USE_PK4_LAYOUT_SHUFFLE vector_type result; - result.template AsType()(Number<0>{}) = pki4_to_half4_scale(bit_cast(x), z); + result.template AsType()(Number<0>{}) = i4_to_half4_scale(bit_cast(x), z); result.template AsType()(Number<1>{}) = - pki4_to_half4_scale(bit_cast(x) >> 8, z); + i4_to_half4_scale(bit_cast(x) >> 8, z); y = result.template AsType()[Number<0>{}]; #else @@ -232,13 +191,13 @@ struct DequantPack8 vector_type src{x}; dst.template AsType()(Number<0>{}) = - pki4_to_half2(src.template AsType()[Number<0>{}]); + type_convert(src.template AsType()[Number<0>{}]); dst.template AsType()(Number<1>{}) = - pki4_to_half2(src.template AsType()[Number<1>{}]); + type_convert(src.template AsType()[Number<1>{}]); dst.template AsType()(Number<2>{}) = - pki4_to_half2(src.template AsType()[Number<2>{}]); + type_convert(src.template AsType()[Number<2>{}]); dst.template AsType()(Number<3>{}) = - pki4_to_half2(src.template AsType()[Number<3>{}]); + type_convert(src.template AsType()[Number<3>{}]); y = dst.template AsType()[Number<0>{}]; #endif @@ -252,7 +211,7 @@ struct PassThroughPack2 template __host__ __device__ void operator()(Y& y, const X& x) const; - __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const + __host__ __device__ constexpr void operator()(half2_t& y, const f8x2_t& x) const { auto t = type_convert(x); y = type_convert(t); @@ -260,7 +219,7 @@ struct PassThroughPack2 __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const { -#if 1 +#if CK_USE_PK4_LAYOUT_SHUFFLE uint8_t x_u8 = ck::bit_cast(x); uint8_t x_l = (x_u8 & 0x0f) >> 0; uint8_t x_h = (x_u8 & 0xf0) >> 4; @@ -479,7 +438,7 @@ struct PassThrough template <> __host__ __device__ void operator()(bf8_t& y, const half_t& x) const { - y = ck::type_convert(x); + y = type_convert(x); } }; @@ -552,21 +511,21 @@ struct Scale template __host__ __device__ void operator()(Y& y, const X& x) const { - y = ck::type_convert(ck::type_convert(x) * scale_); + y = type_convert(type_convert(x) * scale_); } template <> __host__ __device__ void operator()(half_t& y, const half_t& x) const { - y = ck::type_convert(scale_) * x; + y = type_convert(scale_) * x; }; template <> __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { - const float x_tmp = ck::type_convert(x); + const float x_tmp = type_convert(x); const float y_tmp = scale_ * x_tmp; - y = ck::type_convert(y_tmp); + y = type_convert(y_tmp); }; template <> @@ -584,7 +543,7 @@ struct Scale template <> __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { - y = ck::type_convert(scale_ * ck::type_convert(x)); + y = type_convert(scale_ * type_convert(x)); }; float scale_; @@ -600,7 +559,7 @@ struct ScaleAndResetNaNToMinusInfinity template <> __host__ __device__ void operator()(float& y, const float& x) const { - y = ck::math::isnan(x) ? -ck::NumericLimits::Infinity() : scale_ * x; + y = math::isnan(x) ? -NumericLimits::Infinity() : scale_ * x; }; float scale_; @@ -671,12 +630,13 @@ struct UnaryAbs template __host__ __device__ void operator()(T& y, const T& x) const { + static_assert(is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::abs(x); + y = math::abs(x); }; template <> @@ -694,7 +654,7 @@ struct UnarySqrt static_assert(is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::sqrt(x); + y = math::sqrt(x); }; }; @@ -713,9 +673,9 @@ struct Relu template <> __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { - float x_f32 = ck::type_convert(x); + float x_f32 = type_convert(x); float y_f32 = x_f32 > 0 ? x_f32 : 0; - y = ck::type_convert(y_f32); + y = type_convert(y_f32); } }; @@ -731,7 +691,7 @@ struct FastGelu template __device__ void operator()(Y& y, const X& x) const; - +#ifndef CK_CODE_GEN_RTC template <> __host__ void operator()(float& y, const float& x) const { @@ -742,6 +702,7 @@ struct FastGelu const float emu = exp(u); y = x / (1.f + emu); } +#endif // device code, use lower precision "__ocml_exp_f32" and "rcp" template <> @@ -753,7 +714,7 @@ struct FastGelu const float u = x * (c1 * x * x + c2); const float emu = __ocml_exp_f32(u); - y = x * ck::math::rcp(1.f + emu); + y = x * math::rcp(1.f + emu); } template <> @@ -851,10 +812,9 @@ struct Gelu } template <> - __host__ __device__ void operator()(ck::half_t& y, - const ck::half_t& x) const + __host__ __device__ void operator()(half_t& y, const half_t& x) const { - y = ck::half_t(0.5) * x * (ck::half_t(1) + ck::half_t(erf(float(0.70710678118f * x)))); + y = half_t(0.5) * x * (half_t(1) + half_t(erf(float(0.70710678118f * x)))); } }; @@ -868,7 +828,7 @@ struct Sigmoid is_same::value, "Data type is not supported by this operation!"); constexpr T one = type_convert(1); - y = one / (one + ck::math::exp(-x)); + y = one / (one + math::exp(-x)); }; }; @@ -877,11 +837,11 @@ struct Silu template __host__ __device__ void operator()(T& y, const T& x) const { - static_assert(is_same_v || is_same_v || is_same_v || + static_assert(is_same_v || is_same_v || is_same_v || is_same_v || is_same_v, "Data type is not supported by this operation!"); constexpr T one = type_convert(1); - y = x * (one / (one + ck::math::exp(-x))); + y = x * (one / (one + math::exp(-x))); }; }; @@ -895,7 +855,7 @@ struct TanH is_same::value, "Data type is not supported by this operation!"); - y = ck::math::tanh(x); + y = math::tanh(x); }; }; @@ -905,11 +865,11 @@ struct ACos __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::acos(x); + y = math::acos(x); }; }; @@ -919,11 +879,11 @@ struct Neg __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::neg(x); + y = math::neg(x); }; }; @@ -933,11 +893,11 @@ struct ATan __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::atan(x); + y = math::atan(x); }; }; @@ -947,11 +907,11 @@ struct Sin __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::sin(x); + y = math::sin(x); }; }; @@ -961,11 +921,11 @@ struct ASinH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::asinh(x); + y = math::asinh(x); }; }; @@ -975,11 +935,11 @@ struct Cos __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::cos(x); + y = cos(x); }; }; @@ -989,11 +949,11 @@ struct ACosH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::acosh(x); + y = math::acosh(x); }; }; @@ -1003,11 +963,11 @@ struct Tan __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::tan(x); + y = math::tan(x); }; }; @@ -1017,11 +977,11 @@ struct ATanH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::atanh(x); + y = math::atanh(x); }; }; @@ -1031,11 +991,11 @@ struct SinH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::sinh(x); + y = math::sinh(x); }; }; @@ -1045,11 +1005,11 @@ struct Ceil __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::ceil(x); + y = math::ceil(x); }; }; @@ -1059,11 +1019,11 @@ struct Exp __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::exp(x); + y = math::exp(x); }; }; @@ -1073,11 +1033,11 @@ struct CosH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::cosh(x); + y = math::cosh(x); }; }; @@ -1087,11 +1047,11 @@ struct Floor __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::floor(x); + y = math::floor(x); }; }; @@ -1101,11 +1061,11 @@ struct Log __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::log(x); + y = math::log(x); }; }; @@ -1115,11 +1075,11 @@ struct ASin __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::asin(x); + y = math::asin(x); }; }; @@ -1129,11 +1089,11 @@ struct Rcp __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::rcp(x); + y = math::rcp(x); }; }; @@ -1153,7 +1113,7 @@ struct Swish "Data type is not supported by this operation!"); float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); + y = type_convert(x / (1.f + math::exp(bx))); }; const float beta_; @@ -1172,7 +1132,7 @@ struct SoftRelu "Data type is not supported by this operation!"); T casted_alpha = type_convert(alpha_); constexpr T one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; + y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha; } const float alpha_; }; @@ -1193,7 +1153,7 @@ struct Power T casted_beta = type_convert(beta_); T casted_gamma = type_convert(gamma_); T shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); + y = math::pow(shifted_scaled_x, casted_gamma); } const float alpha_; const float beta_; @@ -1213,7 +1173,7 @@ struct ClippedRelu "Data type is not supported by this operation!"); T casted_alpha = type_convert(alpha_); T casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); + y = math::min(casted_beta, math::max(casted_alpha, x)); } const float alpha_; const float beta_; @@ -1248,7 +1208,7 @@ struct Elu is_same::value, "Data type is not supported by this operation!"); T casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); + y = x > 0 ? x : casted_alpha * math::expm1(x); } const float alpha_; }; @@ -1350,10 +1310,10 @@ struct FastNumericArrayConverter }; template <> -struct FastNumericArrayConverter +struct FastNumericArrayConverter { using InputArray = vector_type; - using OutputArray = vector_type; + using OutputArray = vector_type; __device__ static OutputArray convert(InputArray const& Input) { @@ -1383,13 +1343,13 @@ struct FastNumericArrayConverter }; template -struct FastNumericArrayConverter +struct FastNumericArrayConverter { static constexpr int VEC_WIDTH = 4; static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); using InputArray = vector_type; - using OutputArray = vector_type; + using OutputArray = vector_type; __device__ static OutputArray convert(InputArray const& Input) { @@ -1398,7 +1358,7 @@ struct FastNumericArrayConverter OutputArray Output; using Vec_InputArray = vector_type; - using Vec_OutputArray = vector_type; + using Vec_OutputArray = vector_type; Vec_OutputArray* half_4_ptr = reinterpret_cast(&Output); Vec_InputArray const* uint8_4_ptr = reinterpret_cast(&Input); diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 56c37b1b72..2bc9ef87ac 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -1,14 +1,17 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/math.hpp" #include "ck/utility/number.hpp" +#include "ck/utility/tuple.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" +#ifndef CK_CODE_GEN_RTC #include #include +#endif namespace ck { @@ -978,8 +981,7 @@ struct BlockToCTileMap_3DGrid_KSplit // Create 3D grid const auto M0 = math::integer_divide_ceil(M, MPerBlock); const auto N0 = math::integer_divide_ceil(N, NPerBlock); - - return std::make_tuple(N0, M0, k_split); + return make_tuple(N0, M0, k_split); } template @@ -1103,7 +1105,7 @@ struct BlockToCTileMap_GemmStreamK uint32_t dp_for_sk_iters = k_iters_per_tile.get(); uint32_t best_sk_score = - std::numeric_limits::max(); // we need to find the smallest sk iters + NumericLimits::Max(); // we need to find the smallest sk iters for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles; tentative_sk_blocks++) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index 9469fa7bc7..73bac20e43 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -607,6 +607,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // therefore we may just as well assign Gemm1KPack = group_size + constexpr index_t Gemm1KPack = MfmaSelector::selected_mfma.group_size; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index 42f7c2a33f..355e0130f1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -856,11 +856,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); - constexpr index_t Gemm1KPack = math::max( - math::lcm( - MfmaSelector::selected_mfma.group_size, - B1K1), - MfmaSelector::selected_mfma.k_per_blk); + // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size + // selected_mfma.k_per_blk <= Gemm1KPack + // + // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common + // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case + // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs + // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will + // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. + // therefore we may just as well assign Gemm1KPack = group_size + + constexpr index_t Gemm1KPack = + MfmaSelector::selected_mfma.group_size; auto blockwise_gemm1 = BlockwiseGemmXdlops_v2< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index bc76d4cc4f..44a488c5dd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -773,6 +773,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // therefore we may just as well assign Gemm1KPack = group_size + constexpr index_t Gemm1KPack = MfmaSelector::selected_mfma.group_size; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index afb2ad2e76..7d2dfab15f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -628,6 +628,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // therefore we may just as well assign Gemm1KPack = group_size + constexpr index_t Gemm1KPack = MfmaSelector::selected_mfma.group_size; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 150dd98064..344656b13f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -423,10 +423,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle } template - __host__ __device__ static auto - MakeAsGridDescriptor_M_K(const std::array& MRaws, - const std::array& KRaws, - const std::array& AsStride) + __host__ __device__ static auto MakeAsGridDescriptor_M_K( +#ifdef CK_CODE_GEN_RTC + const ck::Array& MRaws, + const ck::Array& KRaws, + const ck::Array& AsStride +#else + const std::array& MRaws, + const std::array& KRaws, + const std::array& AsStride +#endif + ) { return generate_tuple( [&](auto i) { @@ -462,10 +469,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle } template - __host__ __device__ static auto - MakeBsGridDescriptor_N_K(const std::array& NRaws, - const std::array& KRaws, - const std::array& BsStride) + __host__ __device__ static auto MakeBsGridDescriptor_N_K( +#ifdef CK_CODE_GEN_RTC + const ck::Array& NRaws, + const ck::Array& KRaws, + const ck::Array& BsStride +#else + const std::array& NRaws, + const std::array& KRaws, + const std::array& BsStride +#endif + ) { return generate_tuple( [&](auto i) { @@ -500,10 +514,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle } template - __host__ __device__ static auto - MakeDsGridDescriptor_M_N(const std::array& MRaws, - const std::array& NRaws, - const std::array& DsStride) + __host__ __device__ static auto MakeDsGridDescriptor_M_N( +#ifdef CK_CODE_GEN_RTC + const ck::Array& MRaws, + const ck::Array& NRaws, + const ck::Array& DsStride +#else + const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride +#endif + ) { return generate_tuple( [&](auto i) { @@ -969,9 +990,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle const index_t M, const index_t N, const index_t K, +#ifdef CK_CODE_GEN_RTC + const ck::Array StrideAs, + const ck::Array StrideBs, + const ck::Array StrideDs, +#else const std::array StrideAs, const std::array StrideBs, const std::array StrideDs, +#endif const index_t StrideE, const Block2ETileMap& block_2_etile_map) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 4b344c02f8..eb1eb533d7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -473,11 +473,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); } +#ifdef CK_CODE_GEN_RTC + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const ck::Array& MRaws, + const ck::Array& NRaws, + const ck::Array& DsStride) +#else template __host__ __device__ static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, const std::array& NRaws, const std::array& DsStride) +#endif { return generate_tuple( [&](auto i) { @@ -941,7 +949,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const index_t K, const index_t StrideA, const index_t StrideB, +#ifdef CK_CODE_GEN_RTC + const ck::Array StrideDs, +#else const std::array StrideDs, +#endif const index_t StrideE, const Block2ETileMap& block_2_etile_map) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index 44cbbcd049..9dad66913a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -1,10 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once - +#ifndef CK_CODE_GEN_RTC #include #include +#endif #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" @@ -53,12 +54,15 @@ constexpr auto GridwiseGemmPipeline_Selector() } else { +#ifndef CK_CODE_GEN_RTC std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl; +#endif } } } // namespace ck +#ifndef CK_CODE_GEN_RTC inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) { switch(p) @@ -71,3 +75,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) } return os; } +#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index bb1871ae62..21315c2567 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -212,7 +212,7 @@ template ::type = false> struct ThreadwiseTensorSliceTransfer_v2 { - static_assert((InvalidElementAsNaN && !std::is_integral::value) || + static_assert((InvalidElementAsNaN && !ck::is_integral::value) || (!InvalidElementAsNaN), "Filling invalid element as NaN is only for floating point types"); diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 24fac91e22..4f20487b9b 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -37,7 +37,17 @@ enum struct MfmaInstr mfma_f32_32x32x16f8bf8, mfma_f32_16x16x32f8bf8, mfma_f32_32x32x16bf8f8, - mfma_f32_16x16x32bf8f8 + mfma_f32_16x16x32bf8f8, + mfma_f32_32x32x16f16, + mfma_f32_16x16x32f16, + mfma_f32_32x32x16bf16, + mfma_f32_16x16x32bf16, + mfma_i32_32x32x32i8, + mfma_i32_16x16x64i8, + mfma_f32_32x32x64f8f6f4, + mfma_f32_16x16x128f8f6f4, + mfma_scale_f32_32x32x64f8f6f4, + mfma_scale_f32_16x16x128f8f6f4 }; template @@ -198,6 +208,50 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32f16::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -264,6 +318,28 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf16::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -286,6 +362,28 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf16::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -440,6 +538,50 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_32x32x32i8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_16x16x64i8::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -638,16 +780,115 @@ struct mfma_type } }; +// TODO: fix mfma...f8f6f4 instructions +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 32; // n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 32; // from the instruction + static constexpr index_t n_per_blk = 32; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x64f8f6f4::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 16; // == n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 16; // from the instruction + static constexpr index_t n_per_blk = 16; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x128f8f6f4::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 32; // n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 32; // from the instruction + static constexpr index_t n_per_blk = 32; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_scale_f32_32x32x64f8f6f4::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 16; // == n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 16; // from the instruction + static constexpr index_t n_per_blk = 16; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_scale_f32_16x16x128f8f6f4::Run(a, b, reg_c); + } +}; + template + typename additional_type = base_type, + bool is_single_rate_mfma = false> struct MfmaSelector { template + typename additional_type_ = base_type_, + bool is_single_rate_mfma_ = false> static constexpr auto GetMfma(); template <> @@ -711,13 +952,32 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x16f16; +#else + return MfmaInstr::mfma_f32_32x32x8f16; +#endif + } + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x8f16; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x32f16; +#else + return MfmaInstr::mfma_f32_16x16x16f16; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x16f16; } @@ -741,7 +1001,19 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x16bf16; +#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_32x32x8bf16_1k; +#else + return MfmaInstr::mfma_f32_32x32x4bf16; +#endif + } + + template <> + constexpr auto GetMfma() { #if defined(CK_USE_AMD_MFMA_BF16_1K_OP) return MfmaInstr::mfma_f32_32x32x8bf16_1k; @@ -751,7 +1023,19 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x32bf16; +#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_16x16x16bf16_1k; +#else + return MfmaInstr::mfma_f32_16x16x8bf16; +#endif + } + + template <> + constexpr auto GetMfma() { #if defined(CK_USE_AMD_MFMA_BF16_1K_OP) return MfmaInstr::mfma_f32_16x16x16bf16_1k; @@ -760,7 +1044,18 @@ struct MfmaSelector #endif } -#if defined(CK_USE_AMD_MFMA_GFX940) +#if defined(__gfx950__) + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_32x32x32i8; + } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_16x16x64i8; + } +#elif defined(__gfx942__) template <> constexpr auto GetMfma() { @@ -832,8 +1127,8 @@ struct MfmaSelector return MfmaInstr::mfma_f32_16x16x32bf8f8; } - static constexpr auto selected_mfma = - mfma_type()>{}; + static constexpr auto selected_mfma = mfma_type< + GetMfma()>{}; __host__ __device__ constexpr MfmaSelector() { @@ -1135,7 +1430,13 @@ struct XdlopsGemm return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td}; } - static constexpr auto mfma = MfmaSelector{}; + // Falls back to single rate instruction on gfx950 if KPack <= 4; no change on gfx942- + static constexpr auto + mfma = MfmaSelector < base_type, + MPerXdlops, NPerXdlops, additional_type, + ((is_same::value || is_same::value) && KPack <= 4) + ? true + : false > {}; static constexpr auto mfma_instr = mfma.selected_mfma; diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index b91b12ad52..3db94deccb 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -1,10 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck/library/utility/numeric.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -148,8 +147,8 @@ struct TransformConvFwdToGemm template ::type = false> + index_t NDim = NDimSpatial, + typename ck::enable_if::type = false> __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, const ConvDimsType& a_g_n_c_wis_strides, const ConvDimsType& b_g_k_c_xs_lengths, @@ -201,11 +200,15 @@ struct TransformConvFwdToGemm InRightPadW_{input_right_pads[I0]}, ZYX_{X_} { +#ifdef CK_CODE_GEN_RTC + static_assert(is_same_v>); + static_assert(is_same_v>); +#else static_assert(is_same_v> || is_same_v>); static_assert(is_same_v> || is_same_v>); - +#endif if constexpr(SplitN) { N_ = GetSplitedNSize( @@ -219,8 +222,8 @@ struct TransformConvFwdToGemm template ::type = false> + index_t NDim = NDimSpatial, + typename ck::enable_if::type = false> __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, const ConvDimsType& a_g_n_c_wis_strides, const ConvDimsType& b_g_k_c_xs_lengths, @@ -272,11 +275,15 @@ struct TransformConvFwdToGemm InRightPadW_{input_right_pads[I1]}, ZYX_{Y_ * X_} { +#ifdef CK_CODE_GEN_RTC + static_assert(is_same_v>); + static_assert(is_same_v>); +#else static_assert(is_same_v> || is_same_v>); static_assert(is_same_v> || is_same_v>); - +#endif if constexpr(SplitN) { N_ = GetSplitedNSize( @@ -290,8 +297,8 @@ struct TransformConvFwdToGemm template ::type = false> + index_t NDim = NDimSpatial, + typename ck::enable_if::type = false> __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, const ConvDimsType& a_g_n_c_wis_strides, const ConvDimsType& b_g_k_c_xs_lengths, @@ -343,11 +350,15 @@ struct TransformConvFwdToGemm InRightPadW_{input_right_pads[I2]}, ZYX_{Z_ * Y_ * X_} { +#ifdef CK_CODE_GEN_RTC + static_assert(is_same_v>); + static_assert(is_same_v>); +#else static_assert(is_same_v> || is_same_v>); static_assert(is_same_v> || is_same_v>); - +#endif if constexpr(SplitN) { N_ = GetSplitedNSize( @@ -478,11 +489,11 @@ struct TransformConvFwdToGemm // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as // properties template || - is_same_v || - is_same_v), - bool>::type = false> + typename ck::enable_if || + is_same_v || + is_same_v), + bool>::type = false> __host__ __device__ auto MakeADescriptor_M_K() const { if constexpr(ConvForwardSpecialization == @@ -691,11 +702,11 @@ struct TransformConvFwdToGemm } template || - is_same_v || - is_same_v), - bool>::type = false> + typename ck::enable_if || + is_same_v || + is_same_v), + bool>::type = false> __host__ __device__ auto MakeADescriptor_M_K() const { @@ -932,7 +943,7 @@ struct TransformConvFwdToGemm } template || is_same_v || is_same_v), @@ -1242,19 +1253,19 @@ struct TransformConvFwdToGemm } template || - is_same_v || - is_same_v, - bool>::type = false> + typename ck::enable_if || + is_same_v || + is_same_v, + bool>::type = false> __host__ __device__ auto MakeBDescriptor_N_K() const { if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter3x3) { using FilterSizeNumType = - std::conditional_t, - std::conditional_t, Number<27>>>; + ck::conditional_t, + ck::conditional_t, Number<27>>>; if constexpr(NumGroupsToMerge == 1) { @@ -1297,13 +1308,13 @@ struct TransformConvFwdToGemm template < typename BLayout, - typename std::enable_if || - is_same_v || - is_same_v || - is_same_v || - is_same_v || - is_same_v, - bool>::type = false> + typename ck::enable_if || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v, + bool>::type = false> __host__ __device__ auto MakeBDescriptor_N_K() const { const auto wei_k_yx_c_desc = make_naive_tensor_descriptor( @@ -1318,36 +1329,36 @@ struct TransformConvFwdToGemm return wei_gemmn_gemmk_desc; } - template ), - bool>::type = false> + typename ck::enable_if), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_), make_tuple(I0, KStrideTensorC_)); } - template ), - bool>::type = false> + typename ck::enable_if), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), make_tuple(I0, KStrideTensorC_)); } - template ), - bool>::type = false> + typename ck::enable_if), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_), @@ -1355,12 +1366,12 @@ struct TransformConvFwdToGemm } template || - is_same_v || - is_same_v), - bool>::type = false> + index_t NDimSp = NDimSpatial, + typename ck::enable_if || + is_same_v || + is_same_v), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { const IndexType NDoHoWo = N_ * Wo_; @@ -1410,11 +1421,11 @@ struct TransformConvFwdToGemm template || - is_same_v || - is_same_v), - bool>::type = false> + typename ck::enable_if || + is_same_v || + is_same_v), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { const IndexType NDoHoWo = N_ * Ho_ * Wo_; @@ -1467,7 +1478,7 @@ struct TransformConvFwdToGemm template || is_same_v || is_same_v), diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index ad13c44311..328e37d009 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" @@ -581,7 +581,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ tmp.template AsType()[i]); }); } -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) else if constexpr(is_same::value) { vector_type tmp{src_thread_data}; @@ -1021,15 +1021,24 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; static_assert(bytes_per_thread == dword_bytes); +#ifndef CK_CODE_GEN_RTC const uint32_t* global_ptr = reinterpret_cast(reinterpret_cast(global_base_ptr)); +#else + const uint32_t* global_ptr = + reinterpret_cast(reinterpret_cast(global_base_ptr)); +#endif const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size); const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; #if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM T* lds_ptr = lds_base_ptr + lds_offset; +#ifndef CK_CODE_GEN_RTC auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); +#else + auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); +#endif asm volatile("s_mov_b32 m0, %0; \n\t" "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), "v"(global_offset_bytes), @@ -1038,8 +1047,13 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, #else // LDS pointer must be attributed with the LDS address space. __attribute__((address_space(3))) uint32_t* lds_ptr = +#ifndef CK_CODE_GEN_RTC reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( reinterpret_cast(lds_base_ptr + lds_offset)); +#else + reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( + reinterpret_cast(lds_base_ptr + lds_offset)); +#endif llvm_amdgcn_raw_buffer_load_lds( src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index e9174904c9..42b784d303 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -1,8 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/ck.hpp" +#include "ck/utility/enable_if.hpp" #include "ck/utility/random_gen.hpp" #include "ck/utility/type.hpp" @@ -18,39 +20,25 @@ #define CK_USE_OCP_FP8 0 #endif -namespace { -// https://en.cppreference.com/w/cpp/types/conditional -template -struct conditional -{ - using type = T; -}; -template -struct conditional -{ - using type = F; -}; -} // namespace - -namespace ck { - -using f8_fnuz_t = _BitInt(8); -using bf8_fnuz_t = unsigned _BitInt(8); - #if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \ - defined(__gfx1201__)) && \ + defined(__gfx1201__) || defined(__gfx950__)) && \ __HIP_DEVICE_COMPILE__ #define CK_FP8_CVT_FAST_PATH 1 #else #define CK_FP8_CVT_FAST_PATH 0 #endif -#if(defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__ +#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__ #define CK_OCP_FP8_CVT_FAST_PATH 1 #else #define CK_OCP_FP8_CVT_FAST_PATH 0 #endif +namespace ck { + +using f8_fnuz_t = _BitInt(8); +using bf8_fnuz_t = unsigned _BitInt(8); + typedef unsigned char fp8_storage_t; /** @@ -205,10 +193,11 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x) } } - typename conditional< + typename std::conditional< sizeof(T) == 2, unsigned short int, - typename conditional::type>::type retval; + typename std::conditional::type>::type + retval; if constexpr(we == 5 && is_half && !is_fnuz) { @@ -301,7 +290,6 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v) return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false); } } - #endif } // namespace fp8_impl @@ -376,7 +364,7 @@ struct bf8_ocp_t __host__ explicit operator float() const #endif { -#if defined(__gfx1200__) || defined(__gfx1201__) +#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) return fp8_impl::cast_to_f32_from_f8(this->data); #else return fp8_impl::cast_from_f8( @@ -390,7 +378,7 @@ struct bf8_ocp_t __host__ explicit operator _Float16() const #endif { -#if defined(__gfx1200__) || defined(__gfx1201__) +#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8(this->data)); #else return fp8_impl::cast_from_f8<_Float16, wm, we, false>( @@ -424,9 +412,9 @@ __host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a) } template || std::is_same_v || - std::is_same_v || std::is_same_v, - bool> = true> + ck::enable_if_t || is_same_v || + is_same_v || is_same_v, + bool> = true> __host__ __device__ static inline constexpr bool fp8_is_inf(T) { return false; @@ -551,10 +539,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10); - using T_bitwise = typename conditional< + using T_bitwise = typename std::conditional< sizeof(T) == 2, unsigned short int, - typename conditional::type>::type; + typename std::conditional::type>::type; T_bitwise x_bitwise = bit_cast(_x); unsigned long long x{x_bitwise}; @@ -823,7 +811,11 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) if constexpr(stochastic_rounding) { constexpr int seed = 1254739; - rng = prand_generator(reinterpret_cast(&f), f); +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&f), f); +#else + rng = prand_generator(reinterpret_cast(&f), f); +#endif } return cast_to_f8_from_f32( f, rng); @@ -839,7 +831,11 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) if constexpr(stochastic_rounding) { constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC rng = prand_generator(reinterpret_cast(&f), f); +#else + rng = prand_generator(reinterpret_cast(&f), f); +#endif } if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ) diff --git a/include/ck/utility/amd_wave_read_first_lane.hpp b/include/ck/utility/amd_wave_read_first_lane.hpp index d6e1eab314..128c8e9a2c 100644 --- a/include/ck/utility/amd_wave_read_first_lane.hpp +++ b/include/ck/utility/amd_wave_read_first_lane.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,10 +7,12 @@ #include "ck/utility/functional2.hpp" #include "ck/utility/math.hpp" +#ifndef CK_CODE_GEN_RTC #include #include #include #include +#endif namespace ck { namespace detail { @@ -37,7 +39,7 @@ struct get_carrier<3> { using value_type = uint32_t; - std::array bytes; + Array bytes; static_assert(sizeof(bytes) <= sizeof(value_type)); // replacement of host std::copy_n() @@ -61,22 +63,22 @@ struct get_carrier<3> // method to trigger template substitution failure __device__ carrier(const carrier& other) noexcept { - copy_n(other.bytes.begin(), bytes.size(), bytes.begin()); + copy_n(other.bytes.begin(), bytes.Size(), bytes.begin()); } public: __device__ carrier& operator=(value_type value) noexcept { - copy_n(reinterpret_cast(&value), bytes.size(), bytes.begin()); + copy_n(reinterpret_cast(&value), bytes.Size(), bytes.begin()); return *this; } __device__ operator value_type() const noexcept { - std::byte result[sizeof(value_type)]; + ck::byte result[sizeof(value_type)]; - copy_n(bytes.begin(), bytes.size(), result); + copy_n(bytes.begin(), bytes.Size(), result); return *reinterpret_cast(result); } @@ -109,8 +111,8 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value) { constexpr unsigned object_size = sizeof(int64_t); constexpr unsigned second_part_offset = object_size / 2; - auto* const from_obj = reinterpret_cast(&value); - alignas(int64_t) std::byte to_obj[object_size]; + auto* const from_obj = reinterpret_cast(&value); + alignas(int64_t) ck::byte to_obj[object_size]; using Sgpr = uint32_t; @@ -122,17 +124,16 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value) return *reinterpret_cast(to_obj); } -template < - typename Object, - typename = std::enable_if_t && std::is_trivially_copyable_v>> +template && ck::is_trivially_copyable_v>> __device__ auto amd_wave_read_first_lane(const Object& obj) { using Size = unsigned; constexpr Size SgprSize = 4; constexpr Size ObjectSize = sizeof(Object); - auto* const from_obj = reinterpret_cast(&obj); - alignas(Object) std::byte to_obj[ObjectSize]; + auto* const from_obj = reinterpret_cast(&obj); + alignas(Object) ck::byte to_obj[ObjectSize]; constexpr Size RemainedSize = ObjectSize % SgprSize; constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize; diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 5a7030cca7..b125e3adf6 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -5,7 +5,7 @@ namespace ck { // Define the common macro for MI300 models -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__) #define __gfx94__ #endif @@ -134,6 +134,46 @@ struct intrin_mfma_f32_32x32x4f16<32, 64> } }; +template +struct intrin_mfma_f32_32x32x16f16; + +template <> +struct intrin_mfma_f32_32x32x16f16<32, 32> +{ + template + __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_f32_16x16x32f16; + +template <> +struct intrin_mfma_f32_16x16x32f16<16, 16> +{ + template + __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + template struct intrin_mfma_f32_32x32x8f16; @@ -204,6 +244,46 @@ struct intrin_mfma_f32_4x4x4f16<8, 64> }; // bfp16 +template +struct intrin_mfma_f32_32x32x16bf16; + +template <> +struct intrin_mfma_f32_32x32x16bf16<32, 32> +{ + template + __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_f32_16x16x32bf16; + +template <> +struct intrin_mfma_f32_16x16x32bf16<16, 16> +{ + template + __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + template struct intrin_mfma_f32_32x32x8bf16_1k; @@ -298,6 +378,46 @@ struct intrin_mfma_i32_16x16x16i8<16, 16> } }; +template +struct intrin_mfma_i32_32x32x32i8; + +template <> +struct intrin_mfma_i32_32x32x32i8<32, 32> +{ + template + __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_i32_16x16x64i8; + +template <> +struct intrin_mfma_i32_16x16x64i8<16, 16> +{ + template + __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + template struct intrin_mfma_i32_32x32x16i8; @@ -356,6 +476,149 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> } }; +template +struct intrin_mfma_f32_32x32x64f8f6f4; + +/// @brief Performs a matrix fused multiply-accumulate operation on 32x32x64 submatrices for f8, f6, +/// and f4 data types. +/// +/// @note Calls scaled version of the instruction as the original instruction is not supported in +/// the backend. That is the intended use. There is a backend optimization to select the unscaled +/// operation if the scale is 0. +template <> +struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> +{ + template + __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_scale_f32_32x32x64f8f6f4; + +template <> +struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> +{ + template + __device__ static void Run(const f8x32_t& reg_a, + const int32_t scale_a, + const f8x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz + 0, // blgp + 0, // { OPSEL_HI[0], OPSEL[0] }? + scale_a, + 0, // { OPSEL_HI[1], OPSEL[1] }? + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_scale_f32_16x16x128f8f6f4; + +template <> +struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> +{ + template + __device__ static void Run(const f8x32_t& reg_a, + const int32_t scale_a, + const f8x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz + 0, // blgp + 0, // { OPSEL_HI[0], OPSEL[0] }? + scale_a, + 0, // { OPSEL_HI[1], OPSEL[1] }? + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x128f8f6f4; + +/// @brief Performs a matrix fused multiply-accumulate operation on 16x16x128 submatrices for f8f6f4 +/// data types. +/// +/// @note Calls scaled version of the instruction as the original instruction is not supported in +/// the backend. That is the intended use. There is a backend optimization to select the unscaled +/// operation if the scale is 0. +template <> +struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> +{ + template + __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + template struct intrin_mfma_f32_32x32x16f8f8; diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp index 5366c56a9d..2afad00d49 100644 --- a/include/ck/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_ARRAY_HPP #define CK_ARRAY_HPP @@ -38,6 +38,8 @@ struct Array } __host__ __device__ constexpr const TData* begin() const { return &mData[0]; } __host__ __device__ constexpr const TData* end() const { return &mData[NSize]; } + __host__ __device__ constexpr TData* begin() { return &mData[0]; } + __host__ __device__ constexpr TData* end() { return &mData[NSize]; } }; // empty Array @@ -54,7 +56,7 @@ template __host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) { using data_type = remove_cvref_t; - return Array{std::forward(x), std::forward(xs)...}; + return Array{ck::forward(x), ck::forward(xs)...}; } // make empty array diff --git a/include/ck/utility/blkgemmpipe_scheduler.hpp b/include/ck/utility/blkgemmpipe_scheduler.hpp index 902195e2f8..86dcb6c157 100644 --- a/include/ck/utility/blkgemmpipe_scheduler.hpp +++ b/include/ck/utility/blkgemmpipe_scheduler.hpp @@ -90,14 +90,22 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst KPerXDL); printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " - "%d, %d\n C MFMA inst: %d\n", + "%d, %d\n C MFMA inst: %d\n" + "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: " + "%d/ %d\n", A_Buffer_Load_Inst_Num, B_Buffer_Load_Inst_Num, A_LDS_Write_Inst_Num, B_LDS_Write_Inst_Num, A_LDS_Read_Inst_Num, B_LDS_Read_Inst_Num, - C_MFMA_Inst_Num); + C_MFMA_Inst_Num, + A_LDS_Read_Width, + B_LDS_Read_Width, + ALDSWriteWidth, + BLDSWriteWidth, + ABufferLoadWidth, + BBufferLoadWidth); } }; diff --git a/include/ck/utility/container_helper.hpp b/include/ck/utility/container_helper.hpp index 9c7b954565..bd0ca42ecd 100644 --- a/include/ck/utility/container_helper.hpp +++ b/include/ck/utility/container_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_CONTAINER_HELPER_HPP #define CK_CONTAINER_HELPER_HPP @@ -326,14 +326,14 @@ template __host__ __device__ constexpr auto container_concat(const Array& ax, const Array& ay) { return unpack2( - [&](auto&&... zs) { return make_array(std::forward(zs)...); }, ax, ay); + [&](auto&&... zs) { return make_array(ck::forward(zs)...); }, ax, ay); } template __host__ __device__ constexpr auto container_concat(const Tuple& tx, const Tuple& ty) { return unpack2( - [&](auto&&... zs) { return make_tuple(std::forward(zs)...); }, tx, ty); + [&](auto&&... zs) { return make_tuple(ck::forward(zs)...); }, tx, ty); } template diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index d9c954c50f..f90fcf6791 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -4,13 +4,316 @@ #pragma once #include "ck/utility/amd_ck_fp8.hpp" +#include "ck/utility/e8m0.hpp" #include "ck/utility/statically_indexed_array.hpp" - +#ifdef CK_CODE_GEN_RTC +using int8_t = signed char; +using uint8_t = unsigned char; +using int16_t = signed short; +using uint16_t = unsigned short; +using float_t = float; +#endif namespace ck { +#ifdef CK_CODE_GEN_RTC +using byte = unsigned char; +#else +using std::byte; +#endif + using bhalf_t = ushort; using half_t = _Float16; using int4_t = _BitInt(4); +using f4_t = unsigned _BitInt(4); +using f6_t = _BitInt(6); // e2m3 format +using bf6_t = unsigned _BitInt(6); // e3m2 format + +struct f4x2_pk_t +{ + using type = uint8_t; + type data; + f4x2_pk_t() : data{type{}} {} + f4x2_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline type unpack(Number) const + { + static_assert(I < 2, "Index is out of range."); + if constexpr(I == 0) + return data & 0b00001111; + else + return (data >> 4); + } + + __host__ __device__ inline type pack(const type x0, const type x1) + { + return (x1 << 4) | (x0 & 0b00001111); + } +}; + +struct f6x16_pk_t +{ + // store 16 elements of f6_t in an array of 3 uint32_t + using element_type = uint32_t; + using type = StaticallyIndexedArray_v2; + type data; + typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); + f6x16_pk_t() : data{type{}} {} + f6x16_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline f6_t unpack(Number) + { + static_assert(I < 16, "Index out of range for 16 f6_t elements."); + + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 3; + constexpr int bit_pos = I * num_bits_elem; + constexpr int arr_idx = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + uint32_t bits = data.At(Number{}) >> bit_offset; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline type pack(const test_vec_t& x) + { + type packed{}; + + // for each of the 16 f6_t values, place its 6 bits in the correct position + ck::static_for<0, 16, 1>{}([&](auto i) { + uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 3; + constexpr int bit_pos = i * num_bits_elem; + constexpr int arr_index = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = packed.At(Number{}); + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + packed.At(Number{}) = old_value; + + // if it crosses into the next block, shift the remainder + if constexpr(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = packed.At(Number{}); + next_value |= (bits >> (num_bits_elem - overhang)); + packed.At(Number{}) = next_value; + } + }); + + return packed; + } +}; + +struct f6x32_pk_t +{ + // store 32 elements of f6_t in an array of 6 uint32_t + using element_type = uint32_t; + using type = StaticallyIndexedArray_v2; + type data; + typedef int8_t test_vec_t __attribute__((ext_vector_type(32))); + f6x32_pk_t() : data{type{}} {} + f6x32_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline f6_t unpack(Number) + { + static_assert(I < 32, "Index out of range for 32 f6_t elements."); + + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 6; + constexpr int bit_pos = I * num_bits_elem; + constexpr int arr_idx = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + uint32_t bits = data.At(Number{}) >> bit_offset; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline type pack(const test_vec_t& x) + { + type packed{}; + + // for each of the 32 f6_t values, place its 6 bits in the correct position + ck::static_for<0, 32, 1>{}([&](auto i) { + uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 6; + constexpr int bit_pos = i * num_bits_elem; + constexpr int arr_index = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = packed.At(Number{}); + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + packed.At(Number{}) = old_value; + + // if it crosses into the next block, shift the remainder + if constexpr(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = packed.At(Number{}); + next_value |= (bits >> (num_bits_elem - overhang)); + packed.At(Number{}) = next_value; + } + }); + + return packed; + } +}; + +struct bf6x16_pk_t +{ + // store 16 elements of bf6_t in an array of 3 uint32_t + using element_type = uint32_t; + using type = StaticallyIndexedArray_v2; + type data; + typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); + bf6x16_pk_t() : data{type{}} {} + bf6x16_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline bf6_t unpack(Number) + { + static_assert(I < 16, "Index out of range for 16 f6_t elements."); + + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 3; + constexpr int bit_pos = I * num_bits_elem; + constexpr int arr_idx = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + uint32_t bits = data.At(Number{}) >> bit_offset; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline type pack(const test_vec_t& x) + { + type packed{}; + + // for each of the 16 bf6_t values, place its 6 bits in the correct position + ck::static_for<0, 16, 1>{}([&](auto i) { + uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 3; + constexpr int bit_pos = i * num_bits_elem; + constexpr int arr_index = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = packed.At(Number{}); + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + packed.At(Number{}) = old_value; + + // if it crosses into the next block, shift the remainder + if constexpr(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = packed.At(Number{}); + next_value |= (bits >> (num_bits_elem - overhang)); + packed.At(Number{}) = next_value; + } + }); + + return packed; + } +}; + +struct bf6x32_pk_t +{ + // store 32 elements of bf6_t in an array of 6 uint32_t + using element_type = uint32_t; + using type = StaticallyIndexedArray_v2; + type data; + typedef int8_t test_vec_t __attribute__((ext_vector_type(32))); + bf6x32_pk_t() : data{type{}} {} + bf6x32_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline bf6_t unpack(Number) + { + static_assert(I < 32, "Index out of range for 32 f6_t elements."); + + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 6; + constexpr int bit_pos = I * num_bits_elem; + constexpr int arr_idx = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + uint32_t bits = data.At(Number{}) >> bit_offset; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline type pack(const test_vec_t& x) + { + type packed{}; + + // for each of the 32 bf6_t values, place its 6 bits in the correct position + ck::static_for<0, 32, 1>{}([&](auto i) { + uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 6; + constexpr int bit_pos = i * num_bits_elem; + constexpr int arr_index = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = packed.At(Number{}); + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + packed.At(Number{}) = old_value; + + // if it crosses into the next block, shift the remainder + if constexpr(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = packed.At(Number{}); + next_value |= (bits >> (num_bits_elem - overhang)); + packed.At(Number{}) = next_value; + } + }); + + return packed; + } +}; // custom data type - pack int4 data struct pk_i4_t @@ -28,14 +331,15 @@ inline constexpr auto next_pow2(uint32_t x) } // native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, -// native types: bool +// native types: bool, f4_t, f6_t, bf6_t template inline constexpr bool is_native_type() { return is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value; + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value; } // vector_type @@ -217,7 +521,7 @@ struct scalar_type }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using type = d1_t; @@ -253,7 +557,7 @@ struct vector_type()>> __device__ int static err = 0; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -313,7 +617,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -383,7 +687,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -453,7 +757,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d4_t __attribute__((ext_vector_type(4))); @@ -523,7 +827,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -605,7 +909,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -687,7 +991,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d4_t __attribute__((ext_vector_type(4))); @@ -769,7 +1073,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -863,7 +1167,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -967,7 +1271,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -1083,7 +1387,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -1209,7 +1513,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -1358,12 +1662,37 @@ struct nnvb_data_t_selector { using type = f8_ocp_t::data_type; }; + template <> struct nnvb_data_t_selector { using type = bf8_ocp_t::data_type; }; +template <> +struct nnvb_data_t_selector +{ + using type = f6x16_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = f6x32_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf6x16_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf6x32_pk_t::type; +}; + template <> struct nnvb_data_t_selector { @@ -1374,7 +1703,7 @@ template struct non_native_vector_base< T, N, - std::enable_if_t> + ck::enable_if_t> { using data_t = typename nnvb_data_t_selector::type; // select data_t based on the size of T static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); @@ -1470,6 +1799,63 @@ struct non_native_vector_base< } }; +// implementation for f6x16 and f6x32 +template +struct non_native_vector_base> +{ + using data_t = + typename nnvb_data_t_selector::type; // select data_t based on declared base type + using element_t = typename T::element_type; // select element_t based on declared element type + static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); + static constexpr size_t size_factor = + sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6 + using data_v = element_t __attribute__((ext_vector_type(N * size_factor))); + using type = non_native_vector_base; + + union alignas(next_pow2(N * sizeof(T))) + { + data_v dN; // storage vector; + StaticallyIndexedArray dxN; + StaticallyIndexedArray dTxN; + StaticallyIndexedArray dNx1; + } data_; + + __host__ __device__ constexpr non_native_vector_base(data_t a) + : data_{data_v(a.At(Number<0>{}))} + { + } + __host__ __device__ constexpr non_native_vector_base(T f) + : non_native_vector_base(bit_cast(f)) + { + } + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} + + __host__ __device__ constexpr operator data_v() const { return data_.dN; } + __host__ __device__ constexpr operator data_t() const + { + if constexpr(N == 1) + { + return data_.dxN[Number<0>{}]; + } + else + { + return data_.dxN; // XXX this should cause an error + } + } + __host__ __device__ constexpr operator T() const + { + if constexpr(N == 1) + { + return data_.dTxN[Number<0>{}]; + } + else + { + return data_.dTxN; // XXX this should cause an error + } + } +}; + template struct scalar_type>; @@ -1499,7 +1885,7 @@ struct scalar_type> // non-native vector_type implementation template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d1_nnv_t = non_native_vector_base; @@ -1550,7 +1936,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d1_nnv_t = non_native_vector_base; @@ -1613,7 +1999,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d1_nnv_t = non_native_vector_base; @@ -1686,7 +2072,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d1_nnv_t = non_native_vector_base; @@ -1771,7 +2157,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d1_nnv_t = non_native_vector_base; @@ -1866,7 +2252,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d2_t = non_native_vector_base; @@ -1970,7 +2356,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d2_t = non_native_vector_base; @@ -2205,25 +2591,251 @@ using uint8x16_t = typename vector_type::type; using uint8x32_t = typename vector_type::type; using uint8x64_t = typename vector_type::type; +// f4 +using f4x2_t = typename vector_type::type; +using f4x4_t = typename vector_type::type; +using f4x8_t = typename vector_type::type; +using f4x16_t = typename vector_type::type; +using f4x32_t = typename vector_type::type; +using f4x64_t = typename vector_type::type; + +// f6 +using f6x16_t = typename vector_type::type; +using f6x32_t = typename vector_type::type; + +// bf6 +using bf6x16_t = typename vector_type::type; +using bf6x32_t = typename vector_type::type; + // pack int4 using pk_i4x2_t = typename vector_type::type; using pk_i4x4_t = typename vector_type::type; using pk_i4x8_t = typename vector_type::type; +#ifdef CK_CODE_GEN_RTC +template +struct NumericLimits; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; } + + __host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; } + + __host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; } + + __host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr int32_t QuietNaN() { return 0; } +}; +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; } + + __host__ __device__ static constexpr int16_t Min() noexcept { return -32768; } + + __host__ __device__ static constexpr int16_t Max() noexcept { return 32767; } + + __host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr int16_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; } + + __host__ __device__ static constexpr int8_t Min() noexcept { return -128; } + + __host__ __device__ static constexpr int8_t Max() noexcept { return 127; } + + __host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr int8_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; } + + __host__ __device__ static constexpr uint32_t Min() noexcept { return 0; } + + __host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; } + + __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr uint32_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; } + + __host__ __device__ static constexpr uint16_t Min() noexcept { return 0; } + + __host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; } + + __host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr uint16_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + static constexpr unsigned int binary_min = 0x00800000; + static constexpr unsigned int binary_max = 0x7F7FFFFF; + static constexpr unsigned int binary_lowest = 0xFF7FFFFF; + static constexpr unsigned int binary_qnan = 0xFFC00001; + static constexpr unsigned int binary_inf = 0x7F8000000; + + __host__ __device__ static constexpr float Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr float Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr float Lowest() { return bit_cast(binary_lowest); } + + __host__ __device__ static constexpr float QuietNaN() { return bit_cast(binary_qnan); } + + __host__ __device__ static constexpr float Infinity() { return bit_cast(binary_inf); } +}; + +template <> +struct NumericLimits +{ + static constexpr unsigned short binary_min = 0x0400; + static constexpr unsigned short binary_max = 0x7BFF; + static constexpr unsigned short binary_lowest = 0xFBFF; + static constexpr unsigned short binary_qnan = 0x7FFF; + + __host__ __device__ static constexpr half_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr half_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr half_t Lowest() { return bit_cast(binary_lowest); } + + __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast(binary_qnan); } +}; + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); } + + __host__ __device__ static constexpr int4_t Max() { return int4_t(7); } + + __host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); } +}; +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 8 + static constexpr uint8_t binary_min = 0x08; // 0b00001000 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 7 + // static constexpr uint8_t binary_min = 0x08; // 0b00001000 + // static constexpr uint8_t binary_max = 0x77; // 0b01110111 + // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 + + __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); } + + __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); } + + __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); } + + __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); } +}; + +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 16 + static constexpr uint8_t binary_min = 0x04; // 0b00000100 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 15 + // static constexpr uint8_t binary_min = 0x04; // 0b00000100 + // static constexpr uint8_t binary_max = 0x7B; // 0b01111011 + // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= + + __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); } + + __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); } + + __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); } + + __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6 + static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448 + static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448 + static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111 + + __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr f8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } + + __host__ __device__ static constexpr f8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14 + static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344 + static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344 + static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101 + + __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr bf8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } + + __host__ __device__ static constexpr bf8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; +#else template struct NumericLimits { __host__ __device__ static constexpr T Min() { return std::numeric_limits::min(); } - __host__ __device__ static constexpr T Max() { return std::numeric_limits::max(); } - __host__ __device__ static constexpr T Lowest() { return std::numeric_limits::lowest(); } - __host__ __device__ static constexpr T QuietNaN() { return std::numeric_limits::quiet_NaN(); } - __host__ __device__ static constexpr T Infinity() { return std::numeric_limits::infinity(); } }; @@ -2347,6 +2959,119 @@ struct NumericLimits return bit_cast(binary_qnan); } }; +#endif + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x2; // 0b0010 + static constexpr uint8_t binary_max_normal = 0x7; // 0b0111 + static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111 + static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001 + static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001 + + static constexpr float data_max_normal_number = 6; + static constexpr float data_min_subnormal_number = 0.5; + + __host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); } + __host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); } + __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); } + __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); } + __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111 + + static constexpr float data_max_normal_number = 7.5; + static constexpr float data_min_subnormal_number = 0.125; + + __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Lowest() + { + return f6_t(binary_lowest_normal & 0b111111); + } + __host__ __device__ static constexpr f6_t MinSubnorm() + { + return f6_t(binary_min_subnorm & 0b111111); + } + __host__ __device__ static constexpr f6_t MaxSubnorm() + { + return f6_t(binary_max_subnorm & 0b111111); + } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011 + + static constexpr float data_max_normal_number = 28; + static constexpr float data_min_subnormal_number = 0.0625; + + __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); } + __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); } + __host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); } + __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); } + __host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000 + static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110 + static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111 + static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111 + static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000 + static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010 + static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111 + static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110 + + __host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); } + __host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); } + __host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_135() + { + return e8m0_bexp_t(binary_135); + } + __host__ __device__ static constexpr e8m0_bexp_t Binary_142() + { + return e8m0_bexp_t(binary_142); + } +}; template struct NumericUtils @@ -2367,6 +3092,7 @@ struct NumericUtils static constexpr uint32_t NegInf = 0xFF800000; static constexpr uint32_t NaN = 0x7F800001; static constexpr uint32_t Neg0 = 0x80000000; + static constexpr bool has_inf = true; using bitwise_type = uint32_t; }; @@ -2384,9 +3110,19 @@ struct NumericUtils static constexpr uint32_t NegInf = 0xFC00; static constexpr uint32_t NaN = 0x7C01; static constexpr uint32_t Neg0 = 0x8000; + static constexpr bool has_inf = true; using bitwise_type = uint16_t; }; +template <> +struct NumericUtils +{ + static constexpr int exp = 8; + static constexpr int mant = 7; + static constexpr int bias = 128; // negative zero nan mode + // static constexpr int bias = 127; // ieee mode +}; + template <> struct NumericUtils { @@ -2394,6 +3130,7 @@ struct NumericUtils static constexpr int mant = 3; static constexpr int bias = 8; // negative zero nan mode // static constexpr int bias = 7; // ieee mode + static constexpr bool has_inf = false; }; template <> @@ -2403,6 +3140,7 @@ struct NumericUtils static constexpr int mant = 2; static constexpr int bias = 16; // negative zero nan mode // static constexpr int bias = 15; // ieee mode + static constexpr bool has_inf = false; }; template <> struct NumericUtils @@ -2421,11 +3159,109 @@ struct NumericUtils }; template <> -struct NumericUtils +struct NumericUtils +{ + static constexpr int exp = 2; + static constexpr int mant = 1; + static constexpr int bias = 1; + static constexpr uint32_t sr_shift = 10; + + static constexpr int unbiased_exp_min = 0; + static constexpr int unbiased_exp_max = 2; + static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 3; + + static constexpr uint8_t positive_zero_mask = 0b0000; + static constexpr uint8_t negative_zero_mask = 0b1000; + + static constexpr uint8_t one_mask = 0b0010; + static constexpr uint8_t set_sign_mask = 0b0111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b0111; + static constexpr uint8_t data_max_negative_normal_mask = 0b1111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b0001; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b1001; + + static constexpr bool has_inf = false; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 2; + static constexpr int mant = 3; + static constexpr int bias = 1; + static constexpr uint32_t sr_shift = 12; + + static constexpr int unbiased_exp_min = 0; + static constexpr int unbiased_exp_max = 2; + static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 3; + + static constexpr uint8_t positive_zero_mask = 0b000000; + static constexpr uint8_t negative_zero_mask = 0b100000; + + static constexpr uint8_t set_sign_mask = 0b011111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b011111; + static constexpr uint8_t data_max_negative_normal_mask = 0b111111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b000111; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b100111; + + static constexpr bool has_inf = false; + static constexpr bool has_nan = false; + static constexpr bool has_zero = true; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 3; + static constexpr int mant = 2; + static constexpr int bias = 3; + static constexpr uint32_t sr_shift = 11; + + static constexpr int unbiased_exp_min = -2; + static constexpr int unbiased_exp_max = 4; + static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 7; + + static constexpr uint8_t positive_zero_mask = 0b000000; + static constexpr uint8_t negative_zero_mask = 0b100000; + + static constexpr uint8_t set_sign_mask = 0b011111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b011111; + static constexpr uint8_t data_max_negative_normal_mask = 0b111111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b000011; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b100011; + + static constexpr bool has_inf = false; + static constexpr bool has_nan = false; + static constexpr bool has_zero = true; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils { static constexpr int exp = 8; - static constexpr int mant = 7; - static constexpr int bias = 128; // negative zero nan mode - // static constexpr int bias = 127; // ieee mode + static constexpr int mant = 0; + static constexpr int bias = 127; + + static constexpr int unbiased_exp_min = -127; + static constexpr int unbiased_exp_max = 127; + static constexpr int biased_exp_min = 0; + static constexpr int biased_exp_max = 254; + + using bitwise_type = uint8_t; }; } // namespace ck diff --git a/include/ck/utility/debug.hpp b/include/ck/utility/debug.hpp index 03c4e16dd6..2b247cc02a 100644 --- a/include/ck/utility/debug.hpp +++ b/include/ck/utility/debug.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef UTILITY_DEBUG_HPP #define UTILITY_DEBUG_HPP +#include "type.hpp" namespace ck { namespace debug { diff --git a/include/ck/utility/e8m0.hpp b/include/ck/utility/e8m0.hpp new file mode 100644 index 0000000000..a692f533f8 --- /dev/null +++ b/include/ck/utility/e8m0.hpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/type.hpp" + +namespace ck { + +/** + * @brief Unsigned representation of a conventional biased Float32 exponent. + * + * bias = 127; + * + * E8M0_1 = 0b01111111; => 2^(127-127) = 1 + * E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2 + * E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8 + * E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256 + * E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768 + * E8M0_MIN = 0b00000000; => 2^-127 + * E8M0_MAX = 0b11111110; => 2^127 + * E8M0_NAN = 0b11111111; => NaN + */ +struct e8m0_bexp_t +{ + using type = uint8_t; + type data; + + constexpr static type bias = 127; + constexpr static type nan_mask = 0xFF; + + __host__ __device__ constexpr e8m0_bexp_t() : data{type{}} {} + __host__ __device__ constexpr e8m0_bexp_t(type init) : data{init} {} + __host__ __device__ constexpr e8m0_bexp_t(int init) : data{static_cast(init & nan_mask)} + { + } + __host__ __device__ explicit constexpr e8m0_bexp_t(float scale) + : data{static_cast((bit_cast(scale) & (nan_mask << 23)) >> 23)} + { + } + + __host__ __device__ explicit constexpr operator float() const + { + if(data == nan_mask || data == 0) + { + uint32_t bits = data << 1; + bits |= 1; + bits <<= 22; + return bit_cast(bits); + } + else + { + uint32_t bits = data << 23; + return bit_cast(bits); + } + } + + __host__ __device__ constexpr bool operator==(const e8m0_bexp_t& other) const + { + // strict IEEE compliance for NaN + return data == other.data && data != nan_mask; + } + + __host__ __device__ constexpr bool is_nan() const { return data == nan_mask; } +}; + +namespace utils { + +template +__host__ __device__ inline int get_exponent_value(T x); + +template <> +__host__ __device__ inline int get_exponent_value(e8m0_bexp_t x) +{ + return x.data; +} + +} // namespace utils + +} // namespace ck diff --git a/include/ck/utility/enable_if.hpp b/include/ck/utility/enable_if.hpp index c0a3c99f1f..6ba63fc761 100644 --- a/include/ck/utility/enable_if.hpp +++ b/include/ck/utility/enable_if.hpp @@ -1,14 +1,31 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once namespace ck { +#ifndef CK_CODE_GEN_RTC template using enable_if = std::enable_if; template using enable_if_t = typename std::enable_if::type; +#else +template +struct enable_if +{ +}; + +template +struct enable_if +{ + using type = T; +}; + +template +using enable_if_t = typename enable_if::type; +#endif + } // namespace ck diff --git a/include/ck/utility/env.hpp b/include/ck/utility/env.hpp index 6455402dcb..809f302f74 100644 --- a/include/ck/utility/env.hpp +++ b/include/ck/utility/env.hpp @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#ifndef CK_CODE_GEN_RTC #pragma once #include @@ -183,3 +184,4 @@ void UpdateEnvVar(EnvVar, const std::string_view& val) } } // namespace ck +#endif diff --git a/include/ck/utility/functional.hpp b/include/ck/utility/functional.hpp index 91797d2409..cd48ed1747 100644 --- a/include/ck/utility/functional.hpp +++ b/include/ck/utility/functional.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y) { if constexpr(predicate) { - return std::forward(x); + return ck::forward(x); } else { - return std::forward(y); + return ck::forward(y); } } diff --git a/include/ck/utility/functional4.hpp b/include/ck/utility/functional4.hpp index b5f3df8d7c..8e86a296dc 100644 --- a/include/ck/utility/functional4.hpp +++ b/include/ck/utility/functional4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_FUNCTIONAL4_HPP #define CK_FUNCTIONAL4_HPP @@ -21,7 +21,7 @@ struct unpack_impl> template __host__ __device__ constexpr auto operator()(F&& f, X&& x) const { - return std::forward(f)(std::forward(x).At(Number{})...); + return ck::forward(f)(ck::forward(x).At(Number{})...); } }; @@ -35,8 +35,8 @@ struct unpack2_impl, Sequence> template __host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const { - return std::forward(f)(std::forward(x).At(Number{})..., - std::forward(y).At(Number{})...); + return ck::forward(f)(ck::forward(x).At(Number{})..., + ck::forward(y).At(Number{})...); } }; @@ -47,7 +47,7 @@ __host__ __device__ constexpr auto unpack(F&& f, X&& x) { using X_ = remove_reference_t; return detail::unpack_impl::type>{}( - std::forward(f), std::forward(x)); + ck::forward(f), ck::forward(x)); } // TODO: properly implement unpack that takes any number of containers @@ -58,7 +58,7 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y) using Y_ = remove_reference_t; return detail::unpack2_impl::type, typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}( - std::forward(f), std::forward(x), std::forward(y)); + ck::forward(f), ck::forward(x), ck::forward(y)); } } // namespace ck diff --git a/include/ck/utility/integral_constant.hpp b/include/ck/utility/integral_constant.hpp index 376070eb3d..75f35d762c 100644 --- a/include/ck/utility/integral_constant.hpp +++ b/include/ck/utility/integral_constant.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -48,4 +48,9 @@ __host__ __device__ constexpr auto operator%(integral_constant, integral_ return integral_constant{}; } +template +using bool_constant = integral_constant; + +using true_type = bool_constant; +using false_type = bool_constant; } // namespace ck diff --git a/include/ck/utility/is_detected.hpp b/include/ck/utility/is_detected.hpp index 7a324a6c45..a700fcfff1 100644 --- a/include/ck/utility/is_detected.hpp +++ b/include/ck/utility/is_detected.hpp @@ -1,22 +1,24 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/utility/integral_constant.hpp" + namespace ck { namespace detail { template class Op, class... Args> struct detector { - using value_t = std::false_type; + using value_t = integral_constant; using type = Default; }; template class Op, class... Args> -struct detector>, Op, Args...> +struct detector>, Op, Args...> { - using value_t = std::true_type; + using value_t = integral_constant; using type = Op; }; } // namespace detail @@ -32,12 +34,12 @@ template