* Fixed cmake errors related to gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8"
* Fixed cmake build errors related to test_fp8
* Updates to support mixed precision
(cherry picked from commit e65d71180393e7b66169c56565a6bac740427de6)
Co-authored-by: Anca Hamuraru <anca@streamhpc.com>
* Adding support for RRR, F8xF16xF16 gemm_universal_wmma - wip
(cherry picked from commit f8c06322df0abcbd5945a56cdf5bffe56480f9f0)
Co-authored-by: Anca Hamuraru <anca@streamhpc.com>
* Added support for F8xF16xF16 to gemm_wmma_universal
(cherry picked from commit 15c851de6daa513a12c2e3af299bab0176175fb5)
Co-authored-by: Anca Hamuraru <anca@streamhpc.com>
* Added support for F16xF8xF16 to gemm_wmma_universal
* Added support for BF16xI4xBF16 to gemm_wmma_universal
(cherry picked from commit c6a4a69d2d43d59bae8bdabfae80d648646f217e)
Co-authored-by: Anca Hamuraru <anca@streamhpc.com>
* Added support for F16xI4xF16 to gemm_wmma_universal
* Fixed IsSupportedArgument to check ComputeTypeA, ComputeTypeB instead of ADataType, BDataType
* Added missing test class for FP16_KM_NK
* Pre-commit hooks fixes
* Added padding instances for f16xf16xf16
* Fixed cmake errors related to gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8"
(cherry picked from commit 5bdc993dbf)
Co-authored-by: Anca Hamuraru <anca@streamhpc.com>
* Fixed cmake build errors related to test_fp8
(cherry picked from commit 12176616b6)
Co-authored-by: Anca Hamuraru <anca@streamhpc.com>
* Ammending changes for adding support for padding instances for f16xf16xf16
* Fixes for padding instances for f16xf16xf16
* Added padding instances for bf16xbf16, f8xf8
* Added packed instances for bf16xi4xbf16
* Added padding instances for f8xf16xf16
* Added padding instances for f16xf8xf16, f16xi4xf16
* Fixed typos for bf16xbf16xbf16 padding instances
* Fixed typos for padded instances
* Added tests for fp16, KM_KN and KM_NK
* Padding not supported for when BDataType is pk_i4_t. Added fix for correct check and removed padding instances.
* Fixed typos
* Updated the set of tests for FP16
* Updated the set of tests for FP16
* Fix typo
* Moved f16xi4 test under the correct data layout group
* example for gemm_universal_bf16
* Adding examples for gemm_wmma instances
* Added the missing parameters
* Fixed review comments and added executable to cmakeLists
* Fixing clang format
* Fixing build erros
* Fixed compilation failure.
* Modified some code as per gemm_universal_examples
* Fixed the gemm specialization error
* Fixed the build errors.
* Fix strides of a/b_thread_desc
The descriptors are larger than needed (even though the compiler don't alloc registers for unused values).
* Load in M/NRepeat dims with thread copy's slice instead of a loop
* Clone BlockwiseGemmXdlops_pipeline_v1 for WMMA implementation
* Implement Intrawave and Interwave variants of pipeline v1
* Add instances for Interwave and Intrawave v1
* Add instances with ABlockLdsExtraM and BBlockLdsExtraN = 0
* Remove instances that are too slow (mostly because of register spilling)
* Add a workaround for fp8/bf8->f32 packed conversion issue
* Add instances for Interwave and Intrawave v1
* Enable profiling of mixed precision with f8 and int4 on WMMA
* Fix segfault in profiler when B is pk_i4_t
b_device_buf's size in bytes is larger than b_k_n_permute so b_device_buf.ToDevice reads out-of-bounds.
* Remove instances that are too slow (mostly because of register spilling)
* Add missing add_device_gemm_wmma_universal_f8_f8_bf16 declarations
* Add test case for bf16_i4
* Add missing Regular tests
* Add test_gemm_universal_xdl/wmma_fp16 to REGRESSION_TESTS
They take more than 30 seconds
* Fix a bug that fp16_i4 validation passes only with PermuteB
A permutation required by conversion from pk_i4_t to half_t does not
depend on PermuteB, they can be used independently.
* Use PermuteB with f16_i4 in most instances (as xdl)
Some instances use PermuteB = false for checking correctness.
See also the previous commit.
* Fix cache flushing for pk_i4
* Add mixed precision examples
* Disable all tests and instances with f8 on gfx11
Even though f8_f16 and f16_f8 don't require f8 WMMA instructions,
gfx11 still lacks hardware instructions for fast f8->f32 conversion.
* Add FP16 KM_NK and KM_KN test suites for XDL
These tests were added to common .inc for better testing of WMMA instances
* Support multiple D in GridwiseGemm_wmma_cshuffle_v3
DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters.
* Use ThreadGroupTensorSliceTransfer_v7r3
* Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support
* Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma
* Implement DeviceGemmMultipleD_Wmma_CShuffleV3
* Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3
* Prepare gemma_add tests for adding wmma
* Add gemm_add_fastgelu instances and test
* Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API
ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use
DeviceGemmMultipleDSplitK instances there.
* removed unnecessary ck parts from compilation
* initial gemm_add_multiply instance implementations
* fixed profiler help message for gemm_add_multiply
* improved multiply_add profiler layout help
* fixed template arguments for test instances
* added test for gemm_add_multiply
* Support multiple D in GridwiseGemm_wmma_cshuffle_v3
DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters.
* Use ThreadGroupTensorSliceTransfer_v7r3
* Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support
* Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma
* Implement DeviceGemmMultipleD_Wmma_CShuffleV3
* Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3
* Prepare gemma_add tests for adding wmma
* Add gemm_add_fastgelu instances and test
* Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API
ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use
DeviceGemmMultipleDSplitK instances there.
* switched to splitK interface
* log print added to splitk benchmarks
* revert main cmake comments
* newline change reverted
* added add_fastgelu instances
* revert unintended change in xdl add_fastgelu
* created gemm_add_add_fastgelu instances
* created fastegelu instances
* added tests for all splitk fastgelus
* Added tests.
* multiply_add instances created
* updates to add_multiply splitk instances
* splitk xdl test fixes
* added wmma multiply_multiply instances
* fixed ONLY_XDL_AND_WMMA_KERNELS tag
* Added gemm_add examples for wmma v1 and v3
* fixed / workarounded i8 instances
* Modified the v3 code to added one fp16 bxdl instance.
* added bf16 xdl instance.
* adding gemm_add wmma_cshuffle and other support
(cherry picked from commit ec447e7f564095ea969eddc39ec77b843aa52976)
Co-authored-by: Cenxuan <cenxuan@streamhpc.com>
* add instances into camkelists
(cherry picked from commit 23bf2d2771c939ea3ca7f493433c55255bffd08e)
Co-authored-by: Cenxuan <cenxuan@streamhpc.com>
* This is work in progress, edited the template parameters in order to build
(cherry picked from commit b4fde8a3314cb44659c4bbda35f1a0133c63dc41)
Co-authored-by: Cenxuan <cenxuan@streamhpc.com>
* temp work saved, changed the BDataType to f16 or bf16 since wmma currently not support non-equal A and B datatype
(cherry picked from commit 22fbd68f1db458ab50780a394ee2544c7a1484d1)
Co-authored-by: Cenxuan <cenxuan@streamhpc.com>
* added datatype and use clang-format-12
(cherry picked from commit ae4e853682ef1bb27784b2f965b4a66b3751ceec)
Co-authored-by: Cenxuan <cenxuan@streamhpc.com>
* Fixing build errors
* Added instances for v3
* Adding instances and executables
* Code update of template parameters modified.
* Renamed file.
* Added tests.
* resolved error tests.
* Fixing build errors
* Updated comments
* removed the changes as per the MR review comment.
* Updated tests.
* fp8 instances - not tested
* Restored the Cmake file that was reverted by mistake during rebase.
* fixed wmma_op test
* Updated comments.
* Updated the template parameter description
* fixed rdna4 instances
* fixed back compatibility on gfx11
* cleanups
* fix ckProfiler
* one more cmake fix
* added fp8 instances
* Updated tests to ad BF16 instances as per review comment
* Added include file and cleaned up(as per review comment)
* Updated and optimized the example code for all types.
* Fixed clang format
* Resolve "Implement `device_gemm_bilinear` for RDNA4"
* test generalization to handle FP16 shuffle better
* added missing changes
* Added bf16 wmma instance for add_relu
* Added f16 wmma instance and corrected bf16 instance errors.
* Added instances to Cmake
* Modified the template parameters to make the instances work.
* Fixed typo in profiler
* Added v3 instances for gemm_add_relu
* addressed core review comments
* Added test for gemm_add_relu wmma instance
* Cleaned up the code.
* Added examples for gemm_add_relu
* Fixing typo to resolve build errors.
* Fixes applied to fix the precision loss.
* fix billinear test after merge
* Removed the old wmma instances.
* Added wrapper and renamed the wmma_v3 instances
* Updated copyrights and added wrappers.
* Fixes applied according to review comments
* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Robin Voetter <robin@streamhpc.com>
* Removed the old wmma instances.
* Updated wrapper for the v3 instances
* removed the old wmma examples
* Renamed the v3 instances
* Deleted the gtest file added by mistake.
* Updated thge profiler with wrapper
* Fixed test errors.
* Fixed the review comments
* Fixed the if condition MACROS.
* REVERTED THE PROFILER CHANGES
* Revert "REVERTED THE PROFILER CHANGES"
This reverts commit 21cb98546c.
* Revert "Fixed test errors."
This reverts commit 13efcc6fe1.
* Revert "Updated thge profiler with wrapper"
This reverts commit 536f86661d.
* Added missing wrapper instances
* Updated copyrights.
* Fixed typo.
* Fixed copyrights.
* Updated copyrights.
* updated copyrights.
* comments on the atomics workaround
* fixed cmake comment
* Fix bug from merge
* clang-format-18
* Fix compilation error
* Fix linking error
* Fix bug in add and add_relu examples
* Fix error including file (typo)
* Quick fix to compile examples for different targets
* Fix for multi target
* implemented f16 and bf16 instances for gemm_silu
* addressed review comments
* addressed review comments
* Fix clang format
* Fix clang format
---------
Co-authored-by: Anca Hamuraru <anca@streamhpc.com>
Co-authored-by: apoorva <apoorva@streamhpc.com>
Co-authored-by: Anton Gorenko <anton@streamhpc.com>
Co-authored-by: Zoltan Lakatos <zoltan.lakatos@streamhpc.com>
Co-authored-by: Cenxuan <cenxuan@streamhpc.com>
Co-authored-by: Robin Voetter <robin@streamhpc.com>
Co-authored-by: Kiefer van Teutem <kiefer.van.teutem@streamhpc.com>
Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
* Adding RapidJson Library
* Adding Json Dumps in all CK_Tile Examples
Not verified yet
* Adding json to cktile Batched Transpose
* adding json dumps to layernorm2d_fwd
* Adding json dump to flatmm_basic
* Adding RapidJson Library
* Adding Json Dumps in all CK_Tile Examples
Not verified yet
* Adding json to cktile Batched Transpose
* adding json dumps to layernorm2d_fwd
* Adding json dump to flatmm_basic
* Adding json in 03_gemm
* Add json dump to 16_batched_gemm
* Add json dump to gemm_multi_d_fp16
* Add json dump to grouped_gemm
* fix fmha_bwd/fwd
* Fix clang-format errors
exclude include/rapidjson in jenkins as its a third-party library
* Saparating function and defination.
* Update Documentation of 03_gemm
* Refactoring as per code review
* Disable fp8 instances on unsupported targets (#2592)
* Restrict building of gemm_universal_preshuffle_f8 instances to specific targets in CMakeLists.txt
* Add condition to skip gemm_xdl_universal_preshuffle_f8 instances for unsupported targets in CMakeLists.txt
* Add conditions to skip unsupported targets for gemm_universal_preshuffle_f8 and gemm_xdl_universal_preshuffle_f8 instances in CMakeLists.txt
* Refine conditions to exclude gemm_universal_preshuffle_f8 instances for unsupported targets in CMakeLists.txt
---------
Co-authored-by: AviralGoelAMD <aviralgoel@amd.com>
* fix clang format
* remove duplicate lines of code from library/src/tensor_operation_instance/gpu/CMakeLists.txt
* Fixing Readme and unifying jsondumps
* adding moe_smoothquant
* adding fused_moe
* Fixing Readme for batched_gemm
* Fixing Readme for grouped_gemm
* adding flatmm
* adding gemm_multi_d_fp16
* adding elementwise
* adding File name when json is dumped
* Fixing Reduce after merge
* adding batched_transpose
* Adding Warptile in Gemm
* Fixing Clang Format
---------
Co-authored-by: Aviral Goel <aviral.goel@amd.com>
Co-authored-by: AviralGoelAMD <aviralgoel@amd.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
* uncomment all the headdim, use optdim to control
* change default back to -1
* uncomment splitkv instance
* Fix typo in receipt 4 for appendkv
* support optdim for bwd, splitkv and appendkv
* Fix 192 key error
---------
Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: Andy Lugo <Andy.LugoReyes@amd.com>
* Wrap tile size mapping as class method
* Warp pipeline generating as class method
* Add constraint as kernel dispatching criteria
* Support mutltiple tile size for a (hdim, hdim_v) combination
* Use smaller tile size if CU utilization is low
* Use integar as the key of the tile size map
* Fix type error
* Simply override parent class method return value
* Add attribute to eliminate warnging
* Allow using environment variables to turn on/off custom factory
* Unify param naming style
* Add missing HIP runtime include directive
* Fix os.environ.get() usage
* add prefetching physical block id for pagedkv
* start add pagedkv prefill
* rename pipeline
* add kernel for pagedkv
* add an init version pagedkv prefill
* fix redefine issue
* add struct BlockFmhaFwdPagedKVPipelineProblem and fmha_fwd_pagedkv_args
* generate dispatch code
* add body generating code
* comipling pass
* remove dropout from pagedkv
* set lse to false in generating code
* start changing qr kernel to pagedkv
* init version of kernerl with pagedkv
* change names of file that are generated
* chang host validation for pagedkv prefill
* using iglp to change blockgemm
* add kernel files to op head file
* show parameters
* rewrite print parameter fun
* add fwd
* remove default parameter of GridSize
* format
* fix nhead issue and add seqlen_k_ptr to batch mode
* format code
* remove no-longer used code
* format
* fix some comments
---------
Co-authored-by: ltqin <letaoqin@amd.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
* Avoid passing indices (std::vector) by value to host tensor's operator()
Each access requires 2 allocations and copies of the vector.
* Remove 1 unneeded vector copy from the slowest part of fmha_bwd's verification
* Compute ds_hp_host_ref in parallel
This sequntial ForEach is the slowest part of validation and it benefits
from parallel computation.
* Do not use ForEach for simple copy and conversion of large tensors
These tensors all have the same shape {nhead, real_seqlen_q, real_seqlen_k} and
can be copied/converted without complex computations of linear indices.
* - elevate important build messages to log level STATUS
- comment out the rest (temporarily)
* - marked all low importance build messages as log_level=DEBUG
* Add constraint on traits/tile/pipeline
* Use kM0=128 if max_seqlen_q == 8192
* Re-format codegen script
* Remove redundant attr name postix
* Fix import error: default field in dataclass
* Use kK0=64 & kK1=64 to hide latency
* Use CU utilization to decide tile size
* add ck tile examples to package
* Update jenkinsfile
* fix for jenkinsfile
* fix for building ck tile code on non gfx9
* compile ck tile examples only for gfx94
* include ck tile examples in all target
* fix for basic gemm UseStructuredSparsity
* Update CMakeLists.txt
* Update gemm_pipeline_problem.hpp
* add targets to rocm install
---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>