mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Merge remote-tracking branch 'origin/develop' into ck_tile/refactor_moe_sorting
This commit is contained in:
@@ -14,6 +14,7 @@ trigger:
|
||||
branches:
|
||||
include:
|
||||
- develop
|
||||
- amd-develop
|
||||
paths:
|
||||
exclude:
|
||||
- .github
|
||||
|
||||
@@ -103,7 +103,7 @@ if(DPP_KERNELS)
|
||||
endif()
|
||||
option(CK_USE_CODEGEN "Enable codegen library" OFF)
|
||||
if(CK_USE_CODEGEN)
|
||||
add_definitions(-DCK_USE_CODEGEN)
|
||||
add_definitions(-DCK_USE_CODEGEN)
|
||||
endif()
|
||||
|
||||
option(CK_TIME_KERNEL "Enable kernel time tracking" ON)
|
||||
@@ -196,17 +196,20 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
add_definitions(-DCK_USE_XDL)
|
||||
set(CK_USE_XDL "ON")
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95")
|
||||
message("Enabling FP8 gemms on native architectures")
|
||||
add_definitions(-DCK_USE_GFX94)
|
||||
set(CK_USE_GFX94 "ON")
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx95")
|
||||
add_definitions(-DCK_USE_AMD_MFMA_GFX950)
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
message("Enabling WMMA instances")
|
||||
add_definitions(-DCK_USE_WMMA)
|
||||
set(CK_USE_WMMA "ON")
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950")
|
||||
add_definitions(-DCK_USE_OCP_FP8)
|
||||
set(CK_USE_OCP_FP8 "ON")
|
||||
endif()
|
||||
@@ -214,6 +217,10 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx
|
||||
add_definitions(-DCK_USE_FNUZ_FP8)
|
||||
set(CK_USE_FNUZ_FP8 "ON")
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx950")
|
||||
add_definitions(-DCK_USE_NATIVE_MX_SUPPORT)
|
||||
set(CK_USE_NATIVE_MX_SUPPORT "ON")
|
||||
endif()
|
||||
|
||||
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
|
||||
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
|
||||
|
||||
4
Jenkinsfile
vendored
4
Jenkinsfile
vendored
@@ -795,8 +795,8 @@ pipeline {
|
||||
description: "Run the ck_tile FMHA tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_CK_TILE_GEMM_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run the ck_tile GEMM tests (default: OFF)")
|
||||
defaultValue: true,
|
||||
description: "Run the ck_tile GEMM tests (default: ON)")
|
||||
booleanParam(
|
||||
name: "BUILD_INSTANCES_ONLY",
|
||||
defaultValue: false,
|
||||
|
||||
126
client_example/01_gemm/README.md
Normal file
126
client_example/01_gemm/README.md
Normal file
@@ -0,0 +1,126 @@
|
||||
[Back to supported operations](../../../include/ck/README.md)
|
||||
# Composable Kernel GEMM
|
||||
|
||||
## GEMM
|
||||
General matrix multiplications operation. In CK GEMM operation is called as `DeviceGemm` and requires following types as template parameters:
|
||||
|
||||
* **ALayout** - A matrix layout (RowMajor/ColumnMajor).
|
||||
* **BLayout** - B matrix layout (RowMajor/ColumnMajor).
|
||||
* **CLayout** - B matrix layout (RowMajor/ColumnMajor).
|
||||
* **ADataType** - A matrix data type.
|
||||
* **BDataType** - B matrix data type.
|
||||
* **CDataType** - B matrix data type.
|
||||
* **AElementwiseOperation** - Fused operation on tensor A before GEMM.
|
||||
* **BElementwiseOperation** - Fused operation on tensor B before GEMM.
|
||||
* **CElementwiseOperation** - Fused operation on tensor C after GEMM.
|
||||
|
||||
For matrices with large K dimension `DeviceGemmSplitK` implementation is available. This implementation allows user to split K dimension between work groups. This implementation uses `AtomicAdd` operation on global memory, thus need to zero-out output buffer for correct results.
|
||||
|
||||
For fused operations with additional tensor there are `DeviceGemmMultipleABD` or `DeviceGemmMultipleD` operation which require following parameters:
|
||||
* **DsLayout** - layouts for additional tensors for fused operations.
|
||||
* **DsDataType** - data types for additional tensors for fused operations.
|
||||
|
||||
For `DeviceGemmMultipleABD` **ALayout**, **BLayout**, **ADataType** and **BDataType** user should pass a tuple.
|
||||
|
||||
List of the device operations in CK:
|
||||
|
||||
* **DeviceGemmDl** - Device operation with DL instructions.
|
||||
* **DeviceGemmDpp** - Device operation with DL instructions with DPP instructions during data load.
|
||||
* **DeviceGemmWmma_CShuffle** - Device operation with WMMA instructions with CShuffle optimization for more optimized data store.
|
||||
* **DeviceGemm_Xdl_CShuffle_LdsDirectLoad** - Device operation with XDL instructions and CShuffle optimization for more optimized data store and direct load from global memory to shared memory.
|
||||
* **DeviceGemm_Xdl_CShuffle** - Device operation with XDL instructions with CShuffle optimization for more optimized data store.
|
||||
* **DeviceGemm_Xdl_CShuffleV2** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. GEMM pipeline has been optimized compared to **DeviceGemm_Xdl_CShuffle**.
|
||||
* **DeviceGemmXdlSkipBLds** - Device operation with XDL instructions. Load to shared memory has been skiped for B matrix.
|
||||
* **DeviceGemm_Xdl_WaveletModel_CShuffle** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. Producer and consumer scheme cooperation between waves in workgroup.
|
||||
* **DeviceGemmXdl** - Device operation with XDL instructions.
|
||||
|
||||
Table of supported cases by instance factory with XDL instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row:
|
||||
|
||||
| |Is supported|
|
||||
|-------|---|
|
||||
|bf16|✓|
|
||||
|fp16|✓|
|
||||
|fp32|✓|
|
||||
|int8|✓|
|
||||
|fp8 |✓|
|
||||
|
||||
Table of supported cases by instance factory with WMMA instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row:
|
||||
|
||||
| |Is supported|
|
||||
|-------|---|
|
||||
|bf16|✓|
|
||||
|fp16|✓|
|
||||
|fp32|✗|
|
||||
|int8|✓|
|
||||
|fp8 |✗|
|
||||
|
||||
Table of supported cases by instance factory with DL instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row:
|
||||
|
||||
| |Is supported|
|
||||
|-------|---|
|
||||
|bf16|✗|
|
||||
|fp16|✓|
|
||||
|fp32|✓|
|
||||
|int8|✓|
|
||||
|fp8 |✗|
|
||||
|
||||
Table of supported cases by instance factory with fused output elementwise operation:
|
||||
|
||||
* **B Matrix Multiply + Add + Gelu** - bf16 (int8 for B matrix)
|
||||
* **B Matrix Multiply + Add** - bf16 (int8 for B matrix)
|
||||
* **B Matrix Multiply + Gelu** - bf16 (int8 for B matrix)
|
||||
* **B Matrix Multiply** - bf16 (int8 for B matrix)
|
||||
|
||||
* **Add + Add + Gelu** - fp16
|
||||
* **Add + Gelu** - fp16, bf16 (int8 for B matrix) for Row/Column/Row
|
||||
* **Multiply** - fp16
|
||||
* **Add + Multiply** - fp16
|
||||
* **Add + Relu** - fp16 (int8 for B matrix) for Row/Column/Row, bf16 (int8 for B matrix) for Row/Column/Row
|
||||
* **Add + Silu** - fp16 (int8 for B matrix) for Row/Column/Row, bf16 (int8 for B matrix) for Row/Column/Row
|
||||
* **Add** - fp16 (int8 for B matrix) for Row/Column/Row, bf16 (int8 for B matrix) for Row/Column/Row
|
||||
* **Bilinear** - fp16, int8
|
||||
* **Gelu** - fp16
|
||||
* **Multiply + Add** - fp16 for Row/Column/Row and Row/Row/Row, fp16 (int8 for B matrix, fp32 for Bias) for Row/Column/Row and Row/Row/Row,
|
||||
* **Quantization** - int8
|
||||
|
||||
## GEMM V2 (Universal GEMM)
|
||||
General matrix multiplications operation optimized for MI300 series. Operation is called as `DeviceGemmV2` and requires following types as template parameters:
|
||||
|
||||
* **ALayout** - A matrix layout (RowMajor/ColumnMajor).
|
||||
* **BLayout** - B matrix layout (RowMajor/ColumnMajor).
|
||||
* **CLayout** - B matrix layout (RowMajor/ColumnMajor).
|
||||
* **ADataType** - A matrix data type.
|
||||
* **BDataType** - B matrix data type.
|
||||
* **CDataType** - B matrix data type.
|
||||
* **AElementwiseOperation** - Fused operation on tensor A before GEMM.
|
||||
* **BElementwiseOperation** - Fused operation on tensor B before GEMM.
|
||||
* **CElementwiseOperation** - Fused operation on tensor C after GEMM.
|
||||
|
||||
This implementation allows user to split K dimension between work groups. This implementation requires AtomicAdd operation on global memory (output buffer must be set to zeroes if splitK parameter is larger than one).
|
||||
|
||||
List of the device operations for in CK:
|
||||
|
||||
* **DeviceGemm_Xdl_CShuffleV3** - Device operation with XDL instructions with CShuffle optimization for more optimized data store.
|
||||
* **DeviceGemm_Xdl_CShuffleV3R1** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. This implementation perform reduction on splitted K dimension after GEMM instead of AtomicAdd instruction.
|
||||
|
||||
Table of supported cases by instance factory with XDL instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row:
|
||||
|
||||
| |Is supported|
|
||||
|-------|---|
|
||||
|bf16|✓|
|
||||
|fp16|✓|
|
||||
|fp32|✗|
|
||||
|int8|✗|
|
||||
|fp8 (C bf16)|✓|
|
||||
|fp16 (A fp8)|✓|
|
||||
|fp16 (B fp8)|✓|
|
||||
|
||||
## Others
|
||||
|
||||
* **DeviceGemm_dequantB** - GEMM with dequantization (implemented with WMMA instructions).
|
||||
* **DeviceGemmMultipleD_ABScale** - GEMM with scale for A and B matrix.
|
||||
* **DeviceGemmMultipleDLayernorm** - GEMM fused with layernorm.
|
||||
* **DeviceGemmMultipleDMultipleR** - GEMM fused with reductions and custom global reductions operators.
|
||||
* **DeviceGemmReduce** - GEMM fused with reduction.
|
||||
* **DeviceGemm_Streamk_V2** - GEMM stream K implementation. Implementation allows to use reduction instead of AtomicAdd.
|
||||
* **DeviceGemmStreamK** - GEMM stream K implementation using AtomicAdd.
|
||||
68
client_example/07_grouped_convnd_fwd/README.md
Normal file
68
client_example/07_grouped_convnd_fwd/README.md
Normal file
@@ -0,0 +1,68 @@
|
||||
[Back to supported operations](../../../include/ck/README.md)
|
||||
# Composable Kernel Grouped Convolution
|
||||
|
||||
## Grouped Convolution Forward
|
||||
Grouped convolution operation for 1D, 2D or 3D spatial dimensions. Convolution utilizes GEMM kernel after tensor coordinate transform. In CK Grouped Convolution Forward operation is called as `DeviceGroupedConvFwdMultipleABD` and requires following types as template parameters:
|
||||
|
||||
* **NumDimSpatial** - number of spatial dimensions (1D, 2D, 3D).
|
||||
* **InLayout** - input layout (NHWGC, GNHWC, NGCHW).
|
||||
* **WeiLayout** - weight layout (GKYXC).
|
||||
* **DsLayout** - layouts for additional tensors for fused operations.
|
||||
* **OutLayout** - output layout (NHWGK, GNHWK, NGKHW).
|
||||
* **ADataType** - input data type. Pass tuple if there is fused operation with input.
|
||||
* **BDataType** - weight data type. Pass tuple if there is fused operation with weight.
|
||||
* **DsDataType** - data types for additional tensors for fused operations.
|
||||
* **EDataType** - Output data type.
|
||||
* **AElementwiseOperation** - fused operation on tensor A (input).
|
||||
* **BElementwiseOperation** - fused operation on tensor B (weight).
|
||||
* **CDEElementwiseOperation** - fused operation on tensor C (output).
|
||||
* **AComputeType** - compute data type of tensor A for mfma instruction (ADataType by default).
|
||||
* **BComputeType** - compute data type of tensor B for mfma instruction (AComputeType by default).
|
||||
|
||||
Grouped convolution forward support tensors larger than 2GB.
|
||||
|
||||
List of the device operations for grouped convolution forward in CK:
|
||||
|
||||
* **DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3** - Device operation with XDL instructions. Optimized for AMD Instinct MI300 series.
|
||||
* **DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle** - Device operation with XDL instructions and support of fused operations to input, weight and output.
|
||||
* **DeviceGroupedConvFwdMultipleD_Wmma_CShuffle** - Device operation with WMMA instructions.
|
||||
* **DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK** - Device operation with DL instructions.
|
||||
|
||||
Table of supported cases by instance factory with XDL instruction:
|
||||
|
||||
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK|
|
||||
|-------|---|---|---|
|
||||
|bf16 |2D, 3D|2D|1D, 2D, 3D|
|
||||
|fp16 |2D, 3D|2D|1D, 2D, 3D|
|
||||
|fp32 |2D, 3D|2D|1D, 2D, 3D|
|
||||
|int8 |2D, 3D|2D|1D, 3D|
|
||||
|fp8 |3D|✗|✗|
|
||||
|bf8 |3D|✗|✗|
|
||||
|
||||
Table of supported cases by instance factory with WMMA instruction:
|
||||
|
||||
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK|
|
||||
|-------|---|---|---|
|
||||
|fp16 |2D, 3D|✗|2D, 3D|
|
||||
|int8 |2D, 3D|✗|2D, 3D|
|
||||
|
||||
Table of supported cases by instance factory with DL instruction:
|
||||
|
||||
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK|
|
||||
|-------|---|---|---|
|
||||
|bf16 |✗|✗|2D|
|
||||
|fp16 |✗|✗|2D|
|
||||
|fp32 |✗|✗|2D|
|
||||
|int8 |✗|✗|2D|
|
||||
|
||||
Table of supported cases by instance factory with fused elementwise operation:
|
||||
|
||||
* **Dynamic elementwise operation** - 2D/3D, NHWGC, bf16/fp16/fp32/int8
|
||||
* **Bilinear** - 3D, NHWGC, bf16/fp16/fp32/int8
|
||||
* **ConvInvScale** - 3D, NHWGC, fp8
|
||||
* **ConvScale** - 3D, NHWGC, fp8/bf8
|
||||
* **ConvScale + Add** - 3D, NHWGC, fp8
|
||||
* **ConvScale + Relu** - 3D, NHWGC, fp8
|
||||
* **Scale** - 3D, NHWGC, bf16/fp16/fp32/int8
|
||||
* **Scale + Add (for A and B)** - 3D, NHWGC, bf16/fp16/fp32/int8
|
||||
* **Scale + Add + Scale + Add + Relu** - 3D, NHWGC, bf16/fp16/fp32/int8
|
||||
48
client_example/10_grouped_convnd_bwd_data/README.md
Normal file
48
client_example/10_grouped_convnd_bwd_data/README.md
Normal file
@@ -0,0 +1,48 @@
|
||||
[Back to supported operations](../../../include/ck/README.md)
|
||||
# Composable Kernel Grouped Convolution
|
||||
|
||||
## Grouped Convolution Backward Data
|
||||
|
||||
Grouped convolution operation for 1D, 2D or 3D spatial dimensions. Convolution utilizes GEMM kernel after tensor coordinate transform. In CK Grouped Convolution Backward Data operation is called as `DeviceGroupedConvBwdDataMultipleD` and requires following types as template parameters:
|
||||
|
||||
* **NumDimSpatial** - number of spatial dimensions (1D, 2D, 3D).
|
||||
* **ALayout** - output layout (NHWGK, GNHWK, NGKHW).
|
||||
* **BLayout** - weight layout (GKYXC).
|
||||
* **DsLayout** - layouts for additional tensors for fused operations.
|
||||
* **ELayout** - input layout (NHWGC, GNHWC, NGCHW).
|
||||
* **ADataType** - output data type.
|
||||
* **BDataType** - weight data type.
|
||||
* **DsDataType** - data types for additional tensors for fused operations.
|
||||
* **EDataType** - input data type.
|
||||
* **AElementwiseOperation** - fused operation on tensor A (output).
|
||||
* **BElementwiseOperation** - fused operation on tensor B (weight).
|
||||
* **CDEElementwiseOperation** - fused operation on tensor C (input).
|
||||
* **AComputeType** - compute data type of tensor A for mfma instruction (ADataType by default).
|
||||
* **BComputeType** - compute data type of tensor B for mfma instruction (AComputeType by default).
|
||||
|
||||
Grouped convolution backward data supports tensors larger than 2GB (except when image is larger than 2GB).
|
||||
|
||||
List of the device operations for grouped convolution backward data in CK:
|
||||
|
||||
* **DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1** - Device operation with XDL instructions and support of fused operations to input.
|
||||
* **DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle** - Device operation with WMMA instructions.
|
||||
|
||||
Table of supported cases by instance factory with XDL instruction:
|
||||
|
||||
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK|
|
||||
|-------|---|---|---|
|
||||
|bf16|2D, 3D|✗|2D, 3D|
|
||||
|fp16 |2D, 3D|✗|2D, 3D|
|
||||
|fp32 |2D, 3D|✗|2D, 3D|
|
||||
|
||||
Table of supported cases by instance factory with WMMA instruction:
|
||||
|
||||
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK|
|
||||
|-------|---|---|---|
|
||||
|fp16 |2D, 3D|✗|2D, 3D|
|
||||
|int8 |2D, 3D|✗|2D, 3D|
|
||||
|
||||
Table of supported cases by instance factory with fused elementwise operation:
|
||||
|
||||
* **Bilinear** - 3D, NHWGC, bf16/fp16/fp32
|
||||
* **Scale** - 3D, NHWGC, bf16/fp16/fp32
|
||||
62
client_example/11_grouped_conv_bwd_weight/README.md
Normal file
62
client_example/11_grouped_conv_bwd_weight/README.md
Normal file
@@ -0,0 +1,62 @@
|
||||
[Back to supported operations](../../../include/ck/README.md)
|
||||
# Composable Kernel Grouped Convolution
|
||||
|
||||
## Grouped Convolution Backward Weight
|
||||
|
||||
Grouped convolution operation for 1D, 2D or 3D spatial dimensions. Convolution utilizes GEMM kernel after tensor coordinate transform. Backward weight version uses splitK feature (due to large GEMM K dimension). In CK Grouped Convolution Backward Weight operation is called as `DeviceGroupedConvBwdWeight` and requires following types as template parameters:
|
||||
|
||||
* **NumDimSpatial** - number of spatial dimensions (1D, 2D, 3D).
|
||||
* **InLayout** - input layout (NHWGC, GNHWC, NGCHW).
|
||||
* **WeiLayout** - weight layout (GKYXC).
|
||||
* **OutLayout** - output layout (NHWGK, GNHWK, NGKHW).
|
||||
* **InDataType** - input data type.
|
||||
* **WeiDataType** - weight data type.
|
||||
* **OutDataType** - output data type.
|
||||
* **InElementwiseOperation** - fused operation on tensor input.
|
||||
* **WeiElementwiseOperation** - fused operation on tensor weight.
|
||||
* **OutElementwiseOperation** - fused operation on tensor output.
|
||||
* **ComputeTypeA** - compute data type of tensor A for mfma instruction (ADataType by default).
|
||||
* **ComputeTypeB** - compute data type of tensor B for mfma instruction (ComputeTypeA by default).
|
||||
|
||||
For fused operations with additional tensor there is `DeviceGroupedConvBwdWeightMultipleD` operation which requires following parameters:
|
||||
* **DsLayout** - layouts for additional tensors for fused operations.
|
||||
* **DsDataType** - data types for additional tensors for fused operations.
|
||||
|
||||
Grouped convolution backward weight doesn't supports tensors larger than 2GB.
|
||||
|
||||
List of the device operations for grouped convolution backward weight in CK:
|
||||
|
||||
* **DeviceGroupedConvBwdWeight_Xdl_CShuffle** - Device operation with XDL instructions.
|
||||
* **DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle** - Device operation with XDL instructions. Optimized for small C or K.
|
||||
* **DeviceGroupedConvBwdWeight_Wmma_CShuffle** - Device operation with WMMA instructions.
|
||||
* **DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle** - Device operation with XDL instructions and support of fused operations to output.
|
||||
* **DeviceGroupedConvBwdWeight_Dl** - Device operation with DL instructions.
|
||||
|
||||
Table of supported cases by instance factory with XDL instruction:
|
||||
|
||||
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK|
|
||||
|-------|---|---|---|
|
||||
|bf16|2D, 3D|✗|✗|
|
||||
|bf16(fp32 for weight)|2D, 3D|✗|1D, 2D, 3D|
|
||||
|fp16 |2D, 3D|✗|1D, 2D, 3D|
|
||||
|fp32 |2D, 3D|✗|1D, 2D, 3D|
|
||||
|
||||
Table of supported cases by instance factory with WMMA instruction:
|
||||
|
||||
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK|
|
||||
|-------|---|---|---|
|
||||
|fp16 |3D|✗|3D|
|
||||
|int8 |3D|✗|3D|
|
||||
|
||||
Table of supported cases by instance factory with DL instruction:
|
||||
|
||||
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK|
|
||||
|-------|---|---|---|
|
||||
|bf16(fp32 for weight)|1D, 2D, 3D|✗|1D, 2D, 3D|
|
||||
|fp16 |1D, 2D, 3D|✗|1D, 2D, 3D|
|
||||
|fp32 |1D, 2D, 3D|✗|1D, 2D, 3D|
|
||||
|
||||
Table of supported cases by instance factory with fused elementwise operation:
|
||||
|
||||
* **Bilinear** - 3D, NHWGC, bf16(fp32 for weight)/fp16/fp32
|
||||
* **Scale** - 3D, NHWGC, bf16(fp32 for weight)/fp16/fp32
|
||||
@@ -56,7 +56,7 @@ if (GPU_TARGETS)
|
||||
add_definitions(-DCK_USE_WMMA)
|
||||
set(CK_USE_WMMA "ON")
|
||||
endif()
|
||||
if (GPU_TARGETS MATCHES "gfx12")
|
||||
if (GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx950")
|
||||
add_definitions(-DCK_USE_OCP_FP8)
|
||||
set(CK_USE_OCP_FP8 "ON")
|
||||
endif()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/headers.hpp"
|
||||
#include "ck_headers.hpp"
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/types.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
#include <algorithm>
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_gemm_multiple_d/problem.hpp"
|
||||
#include "ck/host/device_gemm_multiple_d/operation.hpp"
|
||||
#include "ck/host/headers.hpp"
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
|
||||
#include "ck/host/headers.hpp"
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
|
||||
#include "ck/host/headers.hpp"
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
|
||||
#include "ck/host/headers.hpp"
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
|
||||
#include "ck/host/headers.hpp"
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL
|
||||
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP
|
||||
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP
|
||||
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
namespace rtc {
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL
|
||||
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER
|
||||
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR
|
||||
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <rtc/hip.hpp>
|
||||
#include <rtc/compile_kernel.hpp>
|
||||
#include <rtc/tmp_dir.hpp>
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <rtc/hip.hpp>
|
||||
#include <rtc/manage_ptr.hpp>
|
||||
#include <stdexcept>
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <rtc/kernel.hpp>
|
||||
#include <rtc/manage_ptr.hpp>
|
||||
#include <rtc/hip.hpp>
|
||||
#include <stdexcept>
|
||||
#include <cassert>
|
||||
|
||||
// extern declare the function since hip/hip_ext.h header is broken
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <rtc/tmp_dir.hpp>
|
||||
#include <algorithm>
|
||||
#include <random>
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
rocm-docs-core==1.14.1
|
||||
rocm-docs-core==1.15.0
|
||||
sphinxcontrib-bibtex==2.6.3
|
||||
|
||||
@@ -199,7 +199,7 @@ requests==2.32.3
|
||||
# via
|
||||
# pygithub
|
||||
# sphinx
|
||||
rocm-docs-core==1.14.1
|
||||
rocm-docs-core==1.15.0
|
||||
# via -r requirements.in
|
||||
rpds-py==0.22.3
|
||||
# via
|
||||
|
||||
@@ -61,7 +61,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
|
||||
|
||||
add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
|
||||
|
||||
list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -31,9 +31,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
|
||||
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>;
|
||||
// // clang-format on
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
|
||||
@@ -16,7 +16,7 @@ if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -32,6 +32,56 @@ using BiasLayout = typename LayoutSettingSelector<NDimSpatial>::BiasLayout;
|
||||
template <ck::index_t NDimSpatial>
|
||||
using ResidualLayout = typename LayoutSettingSelector<NDimSpatial>::ResidualLayout;
|
||||
|
||||
#if defined(CK_USE_AMD_MFMA_GFX950)
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvFwdInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
NDimSpatial,
|
||||
InputLayout<NDimSpatial>,
|
||||
WeightLayout<NDimSpatial>,
|
||||
ck::Tuple<BiasLayout<NDimSpatial>, ResidualLayout<NDimSpatial>>,
|
||||
OutputLayout<NDimSpatial>,
|
||||
InKernelDataType,
|
||||
WeiKernelDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
|
||||
OutKernelDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
1, //
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
64, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
4, // ABlockTransferSrcScalarPerVector
|
||||
4, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
4, // BBlockTransferSrcScalarPerVector
|
||||
4, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1,
|
||||
1,
|
||||
S<1, 16, 1, 16>,
|
||||
4>;
|
||||
#else // defined(CK_USE_AMD_MFMA_GFX950)
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvFwdInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
@@ -80,6 +130,7 @@ using DeviceConvFwdInstance =
|
||||
1,
|
||||
S<1, 16, 1, 16>,
|
||||
4>;
|
||||
#endif // defined(CK_USE_AMD_MFMA_GFX950)
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
|
||||
|
||||
@@ -5,6 +5,6 @@ if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
|
||||
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1")
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
|
||||
endif()
|
||||
|
||||
@@ -5,6 +5,6 @@ if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
|
||||
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1")
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
|
||||
endif()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
5
example/67_gemm_microscaling/CMakeLists.txt
Normal file
5
example/67_gemm_microscaling/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
add_custom_target(example_gemm_mx)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8)
|
||||
|
||||
17
example/67_gemm_microscaling/README.md
Normal file
17
example/67_gemm_microscaling/README.md
Normal file
@@ -0,0 +1,17 @@
|
||||
# GEMM Examples for Microscaling Formats
|
||||
|
||||
## example_gemm_mx_fp8
|
||||
|
||||
```bash
|
||||
# arg1: verification (0=no, 1=CPU)
|
||||
# arg2: initialization (0=no init, 1=integer value, 2=decimal value)
|
||||
# arg3: time kernel (0=no, 1=yes)
|
||||
# arg4: verbosity (0=no info, 1=verbose info)
|
||||
# arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC
|
||||
./bin/example_gemm_mx_fp8 1 1 0 1
|
||||
```
|
||||
|
||||
```bash
|
||||
# Implies: ./bin/example_gemm_mx_fp8 1 2 0 0
|
||||
./bin/example_gemm_mx_fp8
|
||||
```
|
||||
415
example/67_gemm_microscaling/gemm_mx_common.hpp
Normal file
415
example/67_gemm_microscaling/gemm_mx_common.hpp
Normal file
@@ -0,0 +1,415 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
using ScaleDataType = ck::e8m0_bexp_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
int do_verification = 1; // (0=no, 1=CPU)
|
||||
int init_method = 2; // (0=no init, 1=integer value, 2=decimal value)
|
||||
bool time_kernel = false; // (0=no, 1=yes)
|
||||
int verbosity = 0; // (0=no info, 1=verbose info)
|
||||
};
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = -1;
|
||||
ck::index_t StrideB = -1;
|
||||
ck::index_t StrideC = -1;
|
||||
};
|
||||
|
||||
bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
|
||||
{
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 5)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.verbosity = std::stoi(argv[4]);
|
||||
}
|
||||
else if(argc == 11)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.verbosity = std::stoi(argv[4]);
|
||||
|
||||
problem_size.M = std::stoi(argv[5]);
|
||||
problem_size.N = std::stoi(argv[6]);
|
||||
problem_size.K = std::stoi(argv[7]);
|
||||
|
||||
problem_size.StrideA = std::stoi(argv[8]);
|
||||
problem_size.StrideB = std::stoi(argv[9]);
|
||||
problem_size.StrideC = std::stoi(argv[10]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4: verbosity (0=no info, 1=verbose info)" << std::endl
|
||||
<< "arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename XDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename CElementWiseOp,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
ck::index_t MXVectorSize>
|
||||
bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using ELayout = CLayout;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = CElementWiseOp;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
|
||||
|
||||
#if 1
|
||||
// XXX: These parameters should not exist in MX-native GEMM kernel
|
||||
static constexpr ck::index_t Scale_Block_M = 128;
|
||||
static constexpr ck::index_t Scale_Block_N = 128;
|
||||
#endif
|
||||
static constexpr ck::index_t Scale_Block_K = MXVectorSize;
|
||||
|
||||
// XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize MX-specific MFMA
|
||||
// instructions.
|
||||
//
|
||||
// XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize device-optimized
|
||||
// scaled type convert functions.
|
||||
//
|
||||
// XXX: In DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3, KPerBlock is expected to be equal to
|
||||
// ScaleBlockK (aka MXVectorSize).
|
||||
// Additionally, the following is also expected:
|
||||
// static_assert(ScaleBlockM % MPerBlock == 0);
|
||||
// static_assert(ScaleBlockN % NPerBlock == 0);
|
||||
// In MX-native GEMM kernel these requirements should be relaxed.
|
||||
//
|
||||
// XXX: It appears, by default we are using mfma_f32_16x16x4xf32
|
||||
// MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk =
|
||||
// MfmaSelector<float, 16, 16, float>::selected_mfma.k_per_blk = mfma_f32_16x16x4xf32
|
||||
// XXX: GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 assumes scale type is float
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
// ######| ALayout| BLayout| DsLayout| CLayout| ADataType| AScale| BDataType| BScale| DsDataType| CDataType| GemmAcc| CShuffleDataType|AElementwise|BElementwise| CElementwise| GemmSpec|Block| ScaleBlockM| ScaleBlockN| ScaleBlockK| M| N| K| AK1| BK1| M| N|MXdl|NXdl|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer| ABlock|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer| BBlock| CShuffle| CShuffle|CShuffleBlockTransfer|CDEShuffleBlockTransfer| BlkGemm| BlkGemm|ComputeTypeA|ComputeTypeB|LDSTypeA|LDSTypeB|
|
||||
// ######| | | | | | DataType| | DataType| | | DataType| | Operation| Operation| Operation| | Size| | | | Per| Per| Per| | | Per| Per| Per| Per| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|LdsExtraM| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVector| SrcScalar| DstScalar|LdsExtraN| MXdl| NXdl| ClusterLengths| Scalar| PipeSched| PipelineVer| | | | |
|
||||
// ######| | | | | | | | | | | | | | | | | | | | |Block|Block| Block| | | XDL| XDL|Wave|Wave| Lengths| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths| ArrangeOrder| | Dim| PerVector| PerVector_BK1| | PerWave| PerWave| MBlock_MPerBlock| PerVectors| | | | | | |
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | AK0_M_AK1| | | | | | | BK0_N_BK1| | | | | |PerShuffle|PerShuffle| NBlock_NPerBlock| | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, XDataType, BDataType, XDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlkGemmPSched, BlkGemmPVer, float, float, float, float>;
|
||||
// clang-format on
|
||||
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
auto K = problem_size.K;
|
||||
auto StrideA = problem_size.StrideA;
|
||||
auto StrideB = problem_size.StrideB;
|
||||
auto StrideC = problem_size.StrideC;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<ck::index_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<ck::index_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<ck::index_t>(stride);
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
if(K % Scale_Block_K != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! K must be multiple of Scale_Block_K (16 or 32)");
|
||||
};
|
||||
|
||||
auto Scale_Stride_AM = f_get_default_stride(M, K / Scale_Block_K, StrideA, ALayout{});
|
||||
auto Scale_Stride_BN = f_get_default_stride(K / Scale_Block_K, N, StrideB, BLayout{});
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
Tensor<XDataType> a_m_k_scale(
|
||||
f_host_tensor_descriptor(M, K / Scale_Block_K, Scale_Stride_AM, ALayout{})); // scales for A
|
||||
Tensor<XDataType> b_k_n_scale(
|
||||
f_host_tensor_descriptor(K / Scale_Block_K, N, Scale_Stride_BN, BLayout{})); // scales for B
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(
|
||||
f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification
|
||||
Tensor<CDataType> c_m_n_device_result(
|
||||
f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // device result downloaded to host
|
||||
|
||||
if(config.verbosity >= 0)
|
||||
{
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl;
|
||||
std::cout << "c_m_n_device_result: " << c_m_n_device_result.mDesc << std::endl;
|
||||
}
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0:
|
||||
if(config.verbosity > 0)
|
||||
{
|
||||
std::cout << "NOTE: No input data initialization." << std::endl;
|
||||
}
|
||||
break;
|
||||
case 1:
|
||||
case 2:
|
||||
ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.0f)}(a_m_k);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(0.5f)}(a_m_k_scale);
|
||||
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(1.0f)}(b_k_n);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(b_k_n_scale);
|
||||
if(config.verbosity > 0)
|
||||
{
|
||||
std::cout << "Init A = {1}" << std::endl;
|
||||
std::cout << "Init A scale = {0.5}" << std::endl;
|
||||
std::cout << "Init B = {1}" << std::endl;
|
||||
std::cout << "Init B scale = {2.0}" << std::endl;
|
||||
std::cout << "Expect C = {K}" << std::endl;
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
if(config.verbosity > 0)
|
||||
{
|
||||
std::cout << "NOTE: No input data initialization." << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(config.verbosity > 0)
|
||||
std::cout << "Device memory allocation..." << std::endl;
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
if(config.verbosity > 0)
|
||||
std::cout << "Upload data to device..." << std::endl;
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
a_scale_device_buf.ToDevice(a_m_k_scale.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
b_scale_device_buf.ToDevice(b_k_n_scale.mData.data());
|
||||
if(config.verbosity > 0)
|
||||
std::cout << "Done." << std::endl;
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{},
|
||||
c_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, NumDTensor>{},
|
||||
StrideC,
|
||||
a_scale_device_buf.GetDeviceBuffer(),
|
||||
b_scale_device_buf.GetDeviceBuffer(),
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error("wrong!\n"
|
||||
"Provided combination of compilation and runtime parameters is "
|
||||
"not consistent with the supported device_gemm arguments.");
|
||||
}
|
||||
|
||||
if(config.verbosity > 0)
|
||||
std::cout << "Computing GEMM on device..." << std::endl;
|
||||
float ave_time =
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50});
|
||||
|
||||
bool res_verified = true;
|
||||
if(config.do_verification > 0)
|
||||
{
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
if(config.verbosity > 0)
|
||||
{
|
||||
std::cout << "Done." << std::endl;
|
||||
std::cout << "Computing GEMM on host..." << std::endl;
|
||||
}
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
float,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
float,
|
||||
float>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
|
||||
a_m_k_scale,
|
||||
b_k_n,
|
||||
b_k_n_scale,
|
||||
c_m_n_host_result,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
if(config.verbosity > 0)
|
||||
{
|
||||
std::cout << "Done." << std::endl;
|
||||
std::cout << "Comparing results..." << std::endl;
|
||||
}
|
||||
|
||||
if(config.init_method == 1)
|
||||
{
|
||||
res_verified =
|
||||
res_verified && std::abs(static_cast<float>(K) - c_m_n_device_result(0, 0)) <= 0.0f;
|
||||
std::cout << "Expected vs Computed: " << 1.0f * K << " vs " << c_m_n_device_result(0, 0)
|
||||
<< ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl;
|
||||
}
|
||||
|
||||
res_verified = res_verified && ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!");
|
||||
|
||||
if(config.verbosity > 0 && res_verified)
|
||||
std::cout << "Done." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(config.verbosity > 0)
|
||||
std::cout << "Done." << std::endl;
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
std::size_t flop = std::size_t(2) * M * N * K + M * K + K * N; // GEMM + A scale + B scale
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
sizeof(CDataType) * M * N +
|
||||
sizeof(XDataType) * (M * K + K * N) / Scale_Block_K;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s" << std::endl;
|
||||
}
|
||||
|
||||
return res_verified;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename XDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename CElementWiseOp,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
ck::index_t MXVectorSize>
|
||||
bool run_mx_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
return parse_cmd_args(argc, argv, problem_size, config) &&
|
||||
run_mx_gemm<ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
CElementWiseOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
MXVectorSize>(problem_size, config);
|
||||
}
|
||||
41
example/67_gemm_microscaling/gemm_mx_fp8.cpp
Normal file
41
example/67_gemm_microscaling/gemm_mx_fp8.cpp
Normal file
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
#if 1
|
||||
// XXX: MX-native GEMM kernel will work with e8m0_bexp_t scale type
|
||||
using XDataType = float;
|
||||
#else
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
#endif
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using CDataType = float;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t mx_vector_size = 128; // scaling block size
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_example<ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
mx_vector_size>(argc, argv)
|
||||
? 0
|
||||
: -1;
|
||||
}
|
||||
@@ -23,34 +23,34 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
message("adding example ${EXAMPLE_NAME}")
|
||||
set(result 1)
|
||||
if(DEFINED DTYPES)
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
set(test 0)
|
||||
if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if(test EQUAL 1)
|
||||
message("removing example source file ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
set(test 0)
|
||||
if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if(test EQUAL 1)
|
||||
message("removing example source file ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
set(EX_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
@@ -83,6 +83,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any microscaling examples if gfx950 target is not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx950" AND source MATCHES "_mx")
|
||||
message("removing microscaling example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any FP8 examples if CK_ENABLE_FP8 not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8")
|
||||
@@ -102,7 +109,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
if(FILE_NAME MATCHES "_xdl")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
elseif(FILE_NAME MATCHES "_wmma")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
|
||||
elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
@@ -195,7 +204,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
if(FILE_NAME MATCHES "_xdl")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
elseif(FILE_NAME MATCHES "_wmma")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
|
||||
endif()
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
|
||||
@@ -12,7 +12,13 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_basic.hpp"
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
@@ -20,16 +26,12 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr bool kTilePermute = false;
|
||||
// The rank and permutation will also be generate out by the CodeGen part.
|
||||
constexpr ck_tile::index_t kOutputRank = 2;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// This part comes from the Codegen
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 128;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
@@ -37,40 +39,33 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
// Whether doing the CShuffle (transpose before the global memory), depending on the output
|
||||
// layout.
|
||||
constexpr bool CShuffleEpilogue =
|
||||
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
CShuffleEpilogue,
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
kPadM,
|
||||
kPadN,
|
||||
kTilePermute,
|
||||
kOutputRank,
|
||||
1,
|
||||
0,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock>>,
|
||||
ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using CodegenPipelineProblem = ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC>>;
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
@@ -110,12 +105,32 @@ int run_gemm_example(int argc, char* argv[])
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
|
||||
#endif
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
|
||||
@@ -43,6 +43,33 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
|
||||
// ToDo: Add more bias config to support different categories of GEMM.
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
@@ -64,13 +91,23 @@ struct DataTypeTraits<ck_tile::half_t>
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
using Types = GemmBasicTypeConfig<ck_tile::half_t>;
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
// Specific type aliases for easy access
|
||||
using ADataType = Types::ADataType;
|
||||
using BDataType = Types::BDataType;
|
||||
using AccDataType = Types::AccDataType;
|
||||
using CDataType = Types::CDataType;
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
@@ -79,7 +116,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "R", "B tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Column by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
|
||||
@@ -9,6 +9,7 @@ static constexpr inline auto is_row_major(Layout layout_)
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
@@ -29,7 +30,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
|
||||
typename ALayout, typename BLayout, typename CLayout>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
@@ -55,7 +57,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
|
||||
float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
|
||||
float ave_time = gemm_calc<ADataType, BDataType, AccDataType, CDataType,
|
||||
ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
@@ -66,13 +69,19 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
|
||||
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
|
||||
<< " A_Layout =" << ALayout::name
|
||||
<< " B_Layout =" << BLayout::name
|
||||
<< " C_Layout =" << CLayout::name
|
||||
<< " A Type = " << DataTypeTraits<ADataType>::name
|
||||
<< " B Type = " << DataTypeTraits<BDataType>::name
|
||||
<< " C Type = " << DataTypeTraits<CDataType>::name
|
||||
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
|
||||
int run_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
@@ -83,6 +92,11 @@ int run_gemm_example_with_layouts(int argc,
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
@@ -119,7 +133,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
|
||||
invoke_gemm<ADataType, BDataType, AccDataType, CDataType,
|
||||
ALayout, BLayout, CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
@@ -145,7 +160,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
a_m_k, b_k_n, c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
|
||||
(K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
@@ -202,7 +218,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
|
||||
(K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_gpu_ref,
|
||||
"Error: Incorrect results!",
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
#!/bin/sh
|
||||
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
|
||||
VALID=0
|
||||
VALID=1
|
||||
|
||||
for b_matrix_layout in "R" "C"; do
|
||||
|
||||
for b_matrix_layout in "C"; do
|
||||
for m in "64" "512" "1024" "2048"; do
|
||||
for n in "512" "1024" "2048"; do
|
||||
for k in "64" "512" "1024" "2048"; do
|
||||
$EXE -prec=fp16 -b=1 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
$EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
14
example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh
Normal file
14
example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
#!/bin/sh
|
||||
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
|
||||
VALID=1
|
||||
|
||||
|
||||
for b_matrix_layout in "C"; do
|
||||
for m in "64" "512" "1024" "2048"; do
|
||||
for n in "512" "1024" "2048"; do
|
||||
for k in "64" "512" "1024" "2048"; do
|
||||
$EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
@@ -1,12 +1,12 @@
|
||||
#!/bin/sh
|
||||
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
|
||||
VALID=0
|
||||
VALID=1
|
||||
|
||||
for b_matrix_layout in "R" "C"; do
|
||||
for m in "64" "512" "1024" "2048"; do
|
||||
for b_matrix_layout in "C"; do
|
||||
for m in "512" "1024" "2048" "4096"; do
|
||||
for n in "512" "1024" "2048"; do
|
||||
for k in "64" "512" "1024" "2048"; do
|
||||
$EXE -prec=fp16 -b=1 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
for k in "512" "1024" "2048"; do
|
||||
$EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
#!/bin/sh
|
||||
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
|
||||
VALID=1
|
||||
|
||||
for b_matrix_layout in "C"; do
|
||||
for m in "512" "1024" "2048" "4096"; do
|
||||
for n in "512" "1024" "2048"; do
|
||||
for k in "512" "1024" "2048"; do
|
||||
$EXE -prec=bf16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
13
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh
Normal file
13
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh
Normal file
@@ -0,0 +1,13 @@
|
||||
#!/bin/sh
|
||||
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
|
||||
VALID=1
|
||||
|
||||
for b_matrix_layout in "C"; do
|
||||
for m in "512" "1024" "2048" "4096"; do
|
||||
for n in "512" "1024" "2048"; do
|
||||
for k in "512" "1024" "2048"; do
|
||||
$EXE -prec=bf8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
13
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh
Normal file
13
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh
Normal file
@@ -0,0 +1,13 @@
|
||||
#!/bin/sh
|
||||
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
|
||||
VALID=1
|
||||
|
||||
for b_matrix_layout in "C"; do
|
||||
for m in "512" "1024" "2048" "4096"; do
|
||||
for n in "512" "1024" "2048"; do
|
||||
for k in "512" "1024" "2048"; do
|
||||
$EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
@@ -7,22 +7,20 @@ export CK_REPEAT=1
|
||||
|
||||
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
|
||||
|
||||
run_fp16_tests() {
|
||||
for batch in 1 2; do
|
||||
for m in 128 1024; do
|
||||
for n in 128 2048; do
|
||||
for k in 32 64; do
|
||||
run_tests() {
|
||||
for m in 128 1024; do
|
||||
for n in 128 2048; do
|
||||
for k in 64 128; do
|
||||
|
||||
$EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
|
||||
else
|
||||
echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
|
||||
# Optionally, exit or break if you need to halt further execution
|
||||
# exit 1
|
||||
fi
|
||||
$EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "Success: Test with m=$m, n=$n, k=$k executed successfully."
|
||||
else
|
||||
echo "Error: Test with m=$m, n=$n, k=$k failed to execute properly."
|
||||
# Optionally, exit or break if you need to halt further execution
|
||||
# exit 1
|
||||
fi
|
||||
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
@@ -30,6 +28,9 @@ run_fp16_tests() {
|
||||
|
||||
set -x
|
||||
|
||||
run_fp16_tests
|
||||
run_tests "fp16"
|
||||
run_tests "bf16"
|
||||
run_tests "fp8"
|
||||
run_tests "bf8"
|
||||
|
||||
set +x
|
||||
|
||||
@@ -7,22 +7,20 @@ export CK_REPEAT=1
|
||||
|
||||
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
|
||||
|
||||
run_fp16_tests() {
|
||||
for batch in 1 2; do
|
||||
for m in 128 1024; do
|
||||
for n in 128 2048; do
|
||||
for k in 32 64; do
|
||||
run_tests() {
|
||||
for m in 512 1024; do
|
||||
for n in 512 2048; do
|
||||
for k in 512 1024; do
|
||||
|
||||
$EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
|
||||
else
|
||||
echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
|
||||
# Optionally, exit or break if you need to halt further execution
|
||||
# exit 1
|
||||
fi
|
||||
$EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
|
||||
else
|
||||
echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
|
||||
# Optionally, exit or break if you need to halt further execution
|
||||
# exit 1
|
||||
fi
|
||||
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
@@ -30,6 +28,9 @@ run_fp16_tests() {
|
||||
|
||||
set -x
|
||||
|
||||
run_fp16_tests
|
||||
run_tests "fp16"
|
||||
run_tests "bf16"
|
||||
run_tests "fp8"
|
||||
run_tests "bf8"
|
||||
|
||||
set +x
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
@@ -12,7 +12,13 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_basic.hpp"
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
@@ -33,7 +39,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
// Compute friendly for Intrawave scheduler
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
@@ -50,7 +56,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
// ===============================================
|
||||
|
||||
@@ -58,10 +66,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
|
||||
|
||||
using GemmEpilogue = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::
|
||||
@@ -95,6 +101,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
using GemmPipeline =
|
||||
GEMM_PIPELINE<UniversalGemmProblem, ck_tile::UniversalGemmPipelineAgBgCrPolicy>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -230,24 +249,101 @@ int run_gemm_example(int argc, char* argv[])
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -19,12 +19,9 @@ template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
constexpr bool kTilePermute = false;
|
||||
// The rank and permutation will also be generate out by the CodeGen part.
|
||||
constexpr ck_tile::index_t kOutputRank = 2;
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
@@ -41,38 +38,31 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
// Whether doing the CShuffle (transpose before the global memory), depending on the output
|
||||
// layout.
|
||||
constexpr bool CShuffleEpilogue =
|
||||
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
CShuffleEpilogue,
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
kPadM,
|
||||
kPadN,
|
||||
kTilePermute,
|
||||
kOutputRank,
|
||||
1,
|
||||
0,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock>>,
|
||||
ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using CodegenPipelineProblem = ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC>>;
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
@@ -20,12 +20,9 @@ namespace {
|
||||
|
||||
struct GroupedGemmKernelParam
|
||||
{
|
||||
static const bool kPadM = false;
|
||||
static const bool kPadN = false;
|
||||
static const bool kPadK = false;
|
||||
static const bool kTilePermute = false;
|
||||
|
||||
static const ck_tile::index_t kOutputRank = 2;
|
||||
static const bool kPadM = false;
|
||||
static const bool kPadN = false;
|
||||
static const bool kPadK = false;
|
||||
|
||||
static const int kBlockPerCu = 1;
|
||||
static const ck_tile::index_t M_Tile = 128;
|
||||
@@ -54,24 +51,6 @@ using CodegenGemmShape =
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
template <typename CLayout>
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>,
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
GroupedGemmKernelParam::kPadM,
|
||||
GroupedGemmKernelParam::kPadN,
|
||||
GroupedGemmKernelParam::kTilePermute,
|
||||
GroupedGemmKernelParam::kOutputRank,
|
||||
1,
|
||||
0,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock>>,
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
GroupedGemmKernelParam::kPadM,
|
||||
GroupedGemmKernelParam::kPadN>>>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using CodegenGemmTraits = ck_tile::TileGemmTraits<GroupedGemmKernelParam::kPadM,
|
||||
GroupedGemmKernelParam::kPadN,
|
||||
@@ -92,10 +71,25 @@ template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using CodegenGemmPipeline =
|
||||
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
CodegenPipelineProblem<ALayout, BLayout, CLayout>::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GroupedGemmKernelParam::M_Warp,
|
||||
GroupedGemmKernelParam::N_Warp,
|
||||
GroupedGemmKernelParam::M_Warp_Tile,
|
||||
GroupedGemmKernelParam::N_Warp_Tile,
|
||||
GroupedGemmKernelParam::K_Warp_Tile,
|
||||
CodegenPipelineProblem<ALayout, BLayout, CLayout>::TransposeC>>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
|
||||
CodegenGemmPipeline<ALayout, BLayout, CLayout>,
|
||||
GemmEpilogue<CLayout>>;
|
||||
GemmEpilogue<ALayout, BLayout, CLayout>>;
|
||||
}; // namespace
|
||||
|
||||
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
[Back to the main page](../../README.md)
|
||||
# Composable Kernel supported operations
|
||||
## Supported device operations
|
||||
* [Average pooling]()
|
||||
* [Batched contraction]()
|
||||
* [Batched gemm]()
|
||||
* [Batchnorm]()
|
||||
* [CGEMM]()
|
||||
* [Contraction]()
|
||||
* [Convolution]()
|
||||
* [Image to Column and Column to Image]()
|
||||
* [Elementwise]()
|
||||
* [GEMM]()
|
||||
* [Max pooling]()
|
||||
* [Reduce]()
|
||||
* [Normalization]()
|
||||
* [Permute]()
|
||||
* [Put]()
|
||||
* [Softmax]()
|
||||
<!-- * [Average pooling](../../docs/markdown/tensor_operation/average_pooling.md) -->
|
||||
<!-- * [Batched contraction](../../docs/markdown/tensor_operation/batched_contraction.md) -->
|
||||
<!-- * [Batched gemm](../../docs/markdown/tensor_operation/batched_gemm.md) -->
|
||||
<!-- * [Batchnorm](../../docs/markdown/tensor_operation/batchnorm.md) -->
|
||||
<!-- * [CGEMM](../../docs/markdown/tensor_operation/cgemm.md) -->
|
||||
<!-- * [Contraction](../../docs/markdown/tensor_operation/contraction.md) -->
|
||||
<!-- * [Convolution](../../docs/markdown/tensor_operation/convolution.md) -->
|
||||
<!-- * [Elementwise](../../docs/markdown/tensor_operation/elementwise.md) -->
|
||||
* [GEMM](../../client_example/01_gemm/README.md)
|
||||
* [Grouped Convolution Forward](../../client_example/07_grouped_convnd_fwd/README.md)
|
||||
* [Grouped Convolution Backward Data](../../client_example/10_grouped_convnd_bwd_data/README.md)
|
||||
* [Grouped Convolution Backward Weight](../../client_example/11_grouped_conv_bwd_weight/README.md)
|
||||
<!-- * [Grouped GEMM](../../docs/markdown/tensor_operation/grouped_gemm.md) -->
|
||||
<!-- * [Image to Column and Column to Image](../../docs/markdown/tensor_operation/img2col.md) -->
|
||||
<!-- * [Max pooling](../../docs/markdown/tensor_operation/max_pooling.md) -->
|
||||
<!-- * [Reduce](../../docs/markdown/tensor_operation/reduce.md) -->
|
||||
<!-- * [Normalization](../../docs/markdown/tensor_operation/normalization.md) -->
|
||||
<!-- * [Permute](../../docs/markdown/tensor_operation/permute.md) -->
|
||||
<!-- * [Put](../../docs/markdown/tensor_operation/put.md) -->
|
||||
<!-- * [Softmax](../../docs/markdown/tensor_operation/softmax.md) -->
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "ck/config.h"
|
||||
#include "ck/utility/env.hpp"
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
@@ -14,7 +14,7 @@
|
||||
// environment variable to enable logging:
|
||||
// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED
|
||||
CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
||||
|
||||
#endif
|
||||
// to do: add various levels of logging with CK_LOG_LEVEL
|
||||
|
||||
#ifndef CK_TIME_KERNEL
|
||||
@@ -55,10 +55,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
||||
|
||||
// define general macros for various architectures
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)
|
||||
defined(__gfx942__) || defined(__gfx950__)
|
||||
#define __gfx9__
|
||||
#endif
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
|
||||
#define __gfx94__
|
||||
#endif
|
||||
#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
|
||||
@@ -163,6 +163,16 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
||||
// set rounding to nearest even as default for f8 conversions
|
||||
#define CK_USE_SR_F8_CONVERSION 0
|
||||
|
||||
// set rounding to nearest even as default for f6 conversions
|
||||
#define CK_USE_SR_F6_CONVERSION 0
|
||||
|
||||
// set rounding to nearest even as default for f4 conversions
|
||||
#define CK_USE_SR_F4_CONVERSION 0
|
||||
|
||||
// shuffle pk_i4 values during conversion to optimize number of binary
|
||||
// operations
|
||||
#define CK_USE_PK4_LAYOUT_SHUFFLE 1
|
||||
|
||||
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
|
||||
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
|
||||
|
||||
|
||||
@@ -131,6 +131,10 @@
|
||||
#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@
|
||||
#endif
|
||||
|
||||
#ifndef CK_USE_NATIVE_MX_SUPPORT
|
||||
#cmakedefine CK_USE_NATIVE_MX_SUPPORT @CK_USE_NATIVE_MX_SUPPORT@
|
||||
#endif
|
||||
|
||||
// clang-format on
|
||||
|
||||
#endif // CK_CONFIG_H_IN
|
||||
|
||||
@@ -55,20 +55,21 @@ inline bool is_xdl_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
|
||||
ck::get_device_name() == "gfx942";
|
||||
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
|
||||
}
|
||||
|
||||
inline bool is_lds_direct_load_supported()
|
||||
{
|
||||
// Check if direct loads from global memory to LDS are supported.
|
||||
return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" ||
|
||||
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
|
||||
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942" ||
|
||||
ck::get_device_name() == "gfx950";
|
||||
}
|
||||
|
||||
inline bool is_bf16_atomic_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
|
||||
ck::get_device_name() == "gfx942";
|
||||
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
|
||||
}
|
||||
|
||||
inline bool is_gfx101_supported()
|
||||
|
||||
@@ -26,6 +26,7 @@ namespace utils {
|
||||
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
||||
double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
{
|
||||
using F4 = ck::f4_t;
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
@@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> ||
|
||||
is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> ||
|
||||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
|
||||
is_same_v<ComputeDataType, int>,
|
||||
static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
|
||||
is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
|
||||
is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
|
||||
is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
|
||||
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
|
||||
double compute_error = 0;
|
||||
if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
|
||||
@@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
compute_error = std::pow(2, -NumericUtils<ComputeDataType>::mant) * 0.5;
|
||||
}
|
||||
|
||||
static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> ||
|
||||
is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> ||
|
||||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
|
||||
is_same_v<OutDataType, int>,
|
||||
static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
|
||||
is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
|
||||
is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
|
||||
is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
|
||||
"Warning: Unhandled OutDataType for setting up the relative threshold!");
|
||||
double output_error = 0;
|
||||
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
|
||||
@@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
}
|
||||
double midway_error = std::max(compute_error, output_error);
|
||||
|
||||
static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> ||
|
||||
is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> ||
|
||||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
|
||||
is_same_v<AccDataType, int>,
|
||||
static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
|
||||
is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
|
||||
is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
|
||||
is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
|
||||
"Warning: Unhandled AccDataType for setting up the relative threshold!");
|
||||
double acc_error = 0;
|
||||
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
|
||||
@@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
||||
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
|
||||
{
|
||||
using F4 = ck::f4_t;
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
@@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> ||
|
||||
is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> ||
|
||||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
|
||||
is_same_v<ComputeDataType, int>,
|
||||
static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
|
||||
is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
|
||||
is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
|
||||
is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
|
||||
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
|
||||
auto expo = std::log2(std::abs(max_possible_num));
|
||||
double compute_error = 0;
|
||||
@@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
compute_error = std::pow(2, expo - NumericUtils<ComputeDataType>::mant) * 0.5;
|
||||
}
|
||||
|
||||
static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> ||
|
||||
is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> ||
|
||||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
|
||||
is_same_v<OutDataType, int>,
|
||||
static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
|
||||
is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
|
||||
is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
|
||||
is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
|
||||
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
|
||||
double output_error = 0;
|
||||
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
|
||||
@@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
}
|
||||
double midway_error = std::max(compute_error, output_error);
|
||||
|
||||
static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> ||
|
||||
is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> ||
|
||||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
|
||||
is_same_v<AccDataType, int>,
|
||||
static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
|
||||
is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
|
||||
is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
|
||||
is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
|
||||
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
|
||||
double acc_error = 0;
|
||||
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
|
||||
@@ -450,5 +452,54 @@ check_err(const Range& out,
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename Range, typename RefRange>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, f4_t>),
|
||||
bool>
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double rtol = 0.5,
|
||||
double atol = 0.5)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool res{true};
|
||||
int err_count = 0;
|
||||
double err = 0;
|
||||
double max_err = std::numeric_limits<float>::min();
|
||||
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
const double o = type_convert<float>(*std::next(std::begin(out), i));
|
||||
const double r = type_convert<float>(*std::next(std::begin(ref), i));
|
||||
err = std::abs(o - r);
|
||||
|
||||
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
}
|
||||
res = false;
|
||||
}
|
||||
}
|
||||
|
||||
if(!res)
|
||||
{
|
||||
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
|
||||
<< " number of errors: " << err_count << std::endl;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace ck
|
||||
|
||||
@@ -69,6 +69,18 @@ struct GeneratorTensor_1<ck::f8_t>
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<ck::f4_t>
|
||||
{
|
||||
float value = 1.0;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f4_t operator()(Is...)
|
||||
{
|
||||
return ck::type_convert<ck::f4_t>(value);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<int8_t>
|
||||
{
|
||||
@@ -183,6 +195,20 @@ struct GeneratorTensor_2<ck::bf8_t>
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_2<ck::f4_t>
|
||||
{
|
||||
int min_value = 0;
|
||||
int max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f4_t operator()(Is...)
|
||||
{
|
||||
float tmp = (std::rand() % (max_value - min_value)) + min_value;
|
||||
return ck::type_convert<ck::f4_t>(tmp);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_3
|
||||
{
|
||||
@@ -253,6 +279,23 @@ struct GeneratorTensor_3<ck::bf8_t>
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_3<ck::f4_t>
|
||||
{
|
||||
float min_value = 0;
|
||||
float max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f4_t operator()(Is...)
|
||||
{
|
||||
float tmp = float(std::rand()) / float(RAND_MAX);
|
||||
|
||||
float fp32_tmp = min_value + tmp * (max_value - min_value);
|
||||
|
||||
return ck::type_convert<ck::f4_t>(fp32_tmp);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_4
|
||||
{
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -131,7 +131,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
using is_tuple = decltype(ck::declval<T&>().IsTuple());
|
||||
|
||||
template <typename DstBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -18,6 +20,7 @@ enum struct ConvolutionForwardSpecialization
|
||||
Filter3x3,
|
||||
};
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s)
|
||||
{
|
||||
switch(s)
|
||||
@@ -30,6 +33,7 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp
|
||||
default: return "Unrecognized specialization!";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -1,19 +1,21 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <regex>
|
||||
#include <optional>
|
||||
|
||||
#include "ck/stream_config.hpp"
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
#define GET_OBJECT_NAME_IMLP \
|
||||
std::optional<std::string> GetObjectName() const override \
|
||||
{ \
|
||||
@@ -41,7 +43,9 @@ namespace device {
|
||||
}
|
||||
|
||||
#define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL
|
||||
#endif
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
struct BaseArgument
|
||||
{
|
||||
BaseArgument() = default;
|
||||
@@ -66,13 +70,14 @@ struct BaseInvoker
|
||||
|
||||
virtual ~BaseInvoker() {}
|
||||
};
|
||||
#endif
|
||||
|
||||
struct BaseOperator
|
||||
{
|
||||
BaseOperator() = default;
|
||||
BaseOperator(const BaseOperator&) = default;
|
||||
BaseOperator& operator=(const BaseOperator&) = default;
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
|
||||
virtual std::string GetTypeString() const { return ""; }
|
||||
|
||||
@@ -100,7 +105,7 @@ struct BaseOperator
|
||||
assert(p_arg);
|
||||
p_arg->p_workspace_ = p_workspace;
|
||||
}
|
||||
|
||||
#endif
|
||||
virtual ~BaseOperator() {}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
#include <array>
|
||||
#endif
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
@@ -13,8 +15,13 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
#ifdef CK_CODE_GEN_RTC
|
||||
template <typename T>
|
||||
using is_tuple = decltype(ck::declval<T&>().IsTuple());
|
||||
#else
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \brief Grouped Convolution Forward
|
||||
@@ -72,12 +79,18 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
|
||||
|
||||
#ifdef CK_CODE_GEN_RTC
|
||||
using APointers = ck::conditional_t<isMultiA, ck::Array<const void*, NumATensor>&, const void*>;
|
||||
using BPointers = ck::conditional_t<isMultiB, ck::Array<const void*, NumBTensor>&, const void*>;
|
||||
#else
|
||||
// If DataType is tuple, user has to pass std::array with pointers.
|
||||
using APointers =
|
||||
std::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
|
||||
ck::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
|
||||
using BPointers =
|
||||
std::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
|
||||
ck::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
|
||||
#endif
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
|
||||
/**
|
||||
* \brief Make argument pointer for grouped conv fwd.
|
||||
@@ -150,6 +163,7 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
|
||||
const CDEElementwiseOperation& cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -29,6 +29,7 @@ enum struct GemmSpecialization
|
||||
MNKOPadding,
|
||||
};
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
inline std::string getGemmSpecializationString(const GemmSpecialization& s)
|
||||
{
|
||||
switch(s)
|
||||
@@ -52,6 +53,7 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
|
||||
default: return "Unrecognized specialization!";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -3,11 +3,17 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#endif
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
@@ -15,15 +21,12 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -91,8 +94,7 @@ __device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
|
||||
const Block2ETileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -259,8 +261,13 @@ __global__ void
|
||||
|
||||
} // namespace
|
||||
|
||||
#ifdef CK_CODE_GEN_RTC
|
||||
template <typename T>
|
||||
using is_tuple = decltype(ck::declval<T&>().IsTuple());
|
||||
#else
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
#endif
|
||||
|
||||
//
|
||||
// @brief Device Convolution operation.
|
||||
@@ -429,8 +436,8 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert
|
||||
// it to it
|
||||
using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
|
||||
using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
|
||||
using GemmADataType = ck::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
|
||||
using GemmBDataType = ck::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
|
||||
|
||||
#define GridwiseGemmTemplateParameters \
|
||||
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
@@ -449,15 +456,13 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm =
|
||||
std::conditional_t<isMultiA || isMultiB,
|
||||
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmTemplateParameters>,
|
||||
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
|
||||
ck::conditional_t<isMultiA || isMultiB,
|
||||
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmTemplateParameters>,
|
||||
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
|
||||
|
||||
// If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers.
|
||||
using APointers =
|
||||
std::conditional_t<isMultiA, ck::Array<const void*, NumATensor>&, const void*>;
|
||||
using BPointers =
|
||||
std::conditional_t<isMultiB, ck::Array<const void*, NumBTensor>&, const void*>;
|
||||
using APointers = ck::conditional_t<isMultiA, ck::Array<const void*, NumATensor>&, const void*>;
|
||||
using BPointers = ck::conditional_t<isMultiB, ck::Array<const void*, NumBTensor>&, const void*>;
|
||||
// Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not
|
||||
// in initializer list what is required for single const pointer).
|
||||
using AGridPointer = remove_cvref_t<
|
||||
@@ -812,7 +817,6 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
// FIXME: layout
|
||||
if constexpr(is_same_v<DLayout, ctc::G_NW_K> || is_same_v<DLayout, ctc::G_NHW_K> ||
|
||||
is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> ||
|
||||
@@ -965,18 +969,18 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
ck::Array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
ck::Array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
ck::Array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
ck::Array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
ck::Array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
ck::Array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
ck::Array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
ck::Array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
ck::Array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
ck::Array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
|
||||
@@ -56,8 +56,7 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t num_blocks_per_batch =
|
||||
|
||||
@@ -74,8 +74,7 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -60,8 +60,7 @@ __global__ void
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -108,7 +107,7 @@ __global__ void
|
||||
ignore = block_2_ctile_map;
|
||||
ignore = batch_count;
|
||||
ignore = compute_base_ptr_of_batch;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
// Computes C = A * B0 * B1
|
||||
|
||||
@@ -83,8 +83,7 @@ __global__ void
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -68,8 +68,7 @@ __global__ void
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
|
||||
@@ -59,8 +59,7 @@ __global__ void
|
||||
const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -67,8 +67,7 @@ __global__ void
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -127,7 +126,7 @@ __global__ void
|
||||
ignore = batch_count;
|
||||
ignore = compute_base_ptr_of_batch;
|
||||
ignore = c0_matrix_mask;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
// Computes C = A * B0 * B1
|
||||
|
||||
@@ -62,8 +62,7 @@ __global__ void
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -112,7 +111,7 @@ __global__ void
|
||||
ignore = batch_count;
|
||||
ignore = compute_base_ptr_of_batch;
|
||||
ignore = c0_matrix_mask;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
// Computes C = A * B0 * B1
|
||||
|
||||
@@ -52,8 +52,7 @@ __global__ void
|
||||
#endif
|
||||
kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
|
||||
|
||||
@@ -55,8 +55,7 @@ __global__ void
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid,
|
||||
|
||||
@@ -55,8 +55,7 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / num_batches);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
@@ -97,7 +96,7 @@ __global__ void
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = block_2_ctile_map;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
// specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
|
||||
|
||||
@@ -50,9 +50,8 @@ __global__ void
|
||||
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx9__) || \
|
||||
defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
|
||||
|
||||
@@ -63,8 +63,7 @@ __global__ void
|
||||
const Block2ETileMap block_2_etile_map,
|
||||
index_t NRaw)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
|
||||
|
||||
@@ -60,8 +60,7 @@ __global__ void
|
||||
const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
|
||||
@@ -52,8 +52,7 @@ __global__ void
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
|
||||
@@ -138,6 +138,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
@@ -733,7 +734,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
|
||||
<< "Kpack: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::AMmaKStride;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
|
||||
@@ -205,8 +205,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
|
||||
const auto b2c_map = DefaultBlock2CTileMap{};
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
|
||||
const auto K0Padded = karg.K0Padded;
|
||||
ck::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
|
||||
const auto K0Padded = karg.K0Padded;
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded);
|
||||
|
||||
|
||||
@@ -183,8 +183,8 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK<ALayo
|
||||
|
||||
const auto b2c_map = DefaultBlock2CTileMap{};
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
|
||||
const auto K0Padded = karg.K0Padded;
|
||||
ck::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
|
||||
const auto K0Padded = karg.K0Padded;
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded);
|
||||
|
||||
|
||||
@@ -47,8 +47,7 @@ __global__ void
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
|
||||
@@ -37,8 +37,7 @@ __global__ void
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
|
||||
@@ -87,8 +87,7 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const ComputePtrOffsetOfN compute_ptr_offset_of_n)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
|
||||
@@ -60,8 +60,7 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
@@ -103,7 +102,7 @@ __global__ void
|
||||
compute_ptr_offset_of_batch.GetAPtrOffset(0);
|
||||
compute_ptr_offset_of_batch.GetBPtrOffset(0);
|
||||
compute_ptr_offset_of_batch.GetCPtrOffset(0);
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user