Update Blockwise and Gridwise files to support both wave32 & wave64.
1. Calculate WaveSize from template parameter, instead of hard code it to 64, some "64" is also replace with WaveSize
2. Move BN0Shuffled and BK0Shuffled to device side. we can't get correct mfma inst info in host side.
3. Update b_thread_offset_n and b_thread_offset_k in gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp for gfx11. in gfx11, input data is duplicated for each 16 threads, it is different with all of others.
4. Modify a1_threadwise_copy in gridwise_batched_*gemm*gemm for gfx11. for gfx11, we need duplicate input and swizzle A if transposeC isn't enabled.
* Adding fix for the gfx908 to the GEMM MFMA implementaitons of WarpGemmMfmaBf16Bf16F32M4N64K16 WarpGemmMfmaBf16Bf16F32M64N4K16
* Adding support for offload target gfx9-4-generic
* This duplication here isn't ideal
* Start adding other layouts for gemm_ab_scale
* Add some instances
* Create tensor descriptors for A/B scales depending on A/B layout
* Fix formatting
* Revert some comments
* Revert commented instances in CMakeLists.txt
* Add some more instances for col-row gemm
* enable more row,row instances
* Use occupancy=1 for col,row layout to avoid spills
* Split-K autodeduction for DeviceGroupedConvBwdWeight_Xdl_CShuffle and DeviceGroupedConvBwdWeight_Xdl_CShuffleV3.
* Split-K autodeduction for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle.
* Use simple best occupancy model to calculate the split-K.
* Handle split-K autodeduction in explicit gemm conv.
* Add unit tests for split-K autodeduction.
* Remove oversubscription.
* Small fixes.
* Added split-K autodeduction for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle.
* Run clang formatting.
* Fix error handling in the conv profiler.
* Add missing documentation for the autodeducted split-K values.
* Add split-K autodeduction to DeviceGroupedConvBwdWeight_Explicit_Xdl solver.
* Fix clang formatting and split-K profiler documentation.
* Rename max_occupancy value variable.
* Calculate grid size for split-K autodeduction directly from input array shapes and template params.
---------
Co-authored-by: Ville Pietilä <>
1. Port NCHW support from ConvFwd (#2375) to conv bwd data
2. Add new instance device_grouped_conv_bwd_data_xdl_f16_nchw_instances for nchw
Co-authored-by: azhuang <anzhong.huang@amd.com>
* add template for fp16 atomic add
* add template for unsigned short atomic add
* use atomicCAS in atomic add for fp16 and unsigned short
* revrt back to atomic add using casting
1. When conv spec is 1x1 stride1 pad0, nchw is equal with matrix A + column major, we only need minor change in conv transformer to support it.
2. when out is NKHW, it is equal with matrix C with column major. we need swap A & B to get best performance.
3. Add new instance device_grouped_conv_fwd_xdl_f16_nchw_instances for nchw.
* Fix amd_ck_fp8.hpp macro definitions
1. Define CK_USE_FNUZ_FP8 and CK_USE_OCP_FP8 definitions only if they were not defined before.
2. Prefix __assert_fnuz_support and __assert_ocp_support with namespace
fp8_impl to avoid redefined error when building with rocm 6.4+
(rocm/6.4.0/include/hip/amd_detail/amd_hip_fp8.h)
Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com>
* Avoid passing indices (std::vector) by value to host tensor's operator()
Each access requires 2 allocations and copies of the vector.
* Remove 1 unneeded vector copy from the slowest part of fmha_bwd's verification
* Compute ds_hp_host_ref in parallel
This sequntial ForEach is the slowest part of validation and it benefits
from parallel computation.
* Do not use ForEach for simple copy and conversion of large tensors
These tensors all have the same shape {nhead, real_seqlen_q, real_seqlen_k} and
can be copied/converted without complex computations of linear indices.
* Some prep work for adding batched_gemm_wmma_universal. Moved batched_gemm in general to gfx11 and gfx12 categories, and split existing batched_gemm test into xdl and wmma versions. Updated profiler and instance factory. For now only adding f16-row-row-row-GemmDefault. For now actual device instance list is empty.
* Add DeviceBatchedGemm_Wmma_CShuffleV3 based on DeviceGemm_Wmma_CShuffleV3 and make sure it's used in the instance factory and tests. Currently the new batched device level struct cannot actually handle batching, but it does pass tests with a trivial batch size of 1, meaning that the overall structure is good.
* Add custom kernel and Argument type to DeviceBatchedGemm_Wmma_CShuffleV3. Batching arguments not passed to kernel yet.
* Implement kernel-level batching logic for DeviceBatchedGemm_Wmma_CShuffleV3. In principle the whole thing works now, just need to add other data types and perhaps do some cleanup.
* Add other layouts for batched gemm wmma chufflev3 f16 f16 f16. Now matching XDL (for f16).
* Add bf16 bf16 bf16 support for batched gemm wmma cshuffle v3 for all layouts.
* Fixup comments and TODOs
* Expand test cases for batched gemm wmma cshuffle v3 with more unusual shapes. Some of the original test cases for batched gemm do not work based on cshuffle v3 because the dimensions are too small.
* Fix argument order for calls to profile_batched_gemm_impl() ONLY in wmma tests.
* Take batching into account when using rotating memory or clearing the C tensor.
* Implement small refactors / comments etc. from review.
* Port recent gemm wmma updates to batched gemm wmma: V1 pipeline, non-main-k-block-loop, check compute type, packed buffer size calc. Ported new instance lists.
* Add MNKPadding instances to batched gemm wmma cshuffle v3, remove incompatible test problems.
* Put clearing the C matrix in a pre-process lambda for the non-flush case + small fixups.
* Once again switch order of strides and batch strides in calls to profile_batched_gemm_impl() from test_batched_gemm_wmma to match latest definition of that function.
---------
Co-authored-by: kiefer <kiefer.van.teutem@streamhpc.com>