Jing's contribution: prototype of mixed precision gemm FP16/BF16xint4 GEMM (#1762)

* add a prototype of int4

* clean

* debug

* clean

* clean

* move packed into dynamic_buffer

* fixed coord reset

* add fast pki4 to half conversion

* fix

* fixed reference and host_tensor

* fixed tensor init

* format

* debug i4_to_f16_convert

* format

* fixed splitk

* weight permute

* add b tile permute

* clean

* weight permute with splitki

* format

* improve weight layout

* add and_or_b32

* fixed splitk crush

* add permute switch as a template

* recover v3r1

* clean

* failure with intrawave v2

* fixed

* fixed

* add ckProfiler

* add bfp16 support

* add bf16 example

* fixed int4 to bhalf_t conversion

* format

* fixed int4 to bf16 conversion

* clean

* add instances for mem

* clean

* fixed host tensor size

* fixed

* debug

* fixed

* add pk_i4_t as a struct

* fix

* Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* revert

* Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* fixed comments

* revert

* clean

* revert

* revert

* fixed

* Update CMakeLists.txt

* Update script/cmake-ck-dev.sh

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update CMakeLists.txt

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* fixed

* fixed

* fixed

* revert

* revert

* add comments

* format

* fixed assert

* fixed

* Fix I4 define in ckProfiler

* Fixed example_gemm_xdl_bf16_pk_i4_v3 test failed issue

---------

Co-authored-by: Jing Zhang <jizhan@fb.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: mtgu0705 <mtgu@amd.com>
This commit is contained in:
Adam Osewski
2025-01-02 04:48:06 +01:00
committed by GitHub
parent 159fa31946
commit 1d8e4ec2ce
37 changed files with 1582 additions and 349 deletions

View File

@@ -287,3 +287,85 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
return true;
}
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1e-1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1.5e-1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 16.1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 8192.1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}