Jing's contribution: prototype of mixed precision gemm FP16/BF16xint4 GEMM (#1762)

* add a prototype of int4

* clean

* debug

* clean

* clean

* move packed into dynamic_buffer

* fixed coord reset

* add fast pki4 to half conversion

* fix

* fixed reference and host_tensor

* fixed tensor init

* format

* debug i4_to_f16_convert

* format

* fixed splitk

* weight permute

* add b tile permute

* clean

* weight permute with splitki

* format

* improve weight layout

* add and_or_b32

* fixed splitk crush

* add permute switch as a template

* recover v3r1

* clean

* failure with intrawave v2

* fixed

* fixed

* add ckProfiler

* add bfp16 support

* add bf16 example

* fixed int4 to bhalf_t conversion

* format

* fixed int4 to bf16 conversion

* clean

* add instances for mem

* clean

* fixed host tensor size

* fixed

* debug

* fixed

* add pk_i4_t as a struct

* fix

* Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* revert

* Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* fixed comments

* revert

* clean

* revert

* revert

* fixed

* Update CMakeLists.txt

* Update script/cmake-ck-dev.sh

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* Update CMakeLists.txt

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

* fixed

* fixed

* fixed

* revert

* revert

* add comments

* format

* fixed assert

* fixed

* Fix I4 define in ckProfiler

* Fixed example_gemm_xdl_bf16_pk_i4_v3 test failed issue

---------

Co-authored-by: Jing Zhang <jizhan@fb.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: mtgu0705 <mtgu@amd.com>
This commit is contained in:
Adam Osewski
2025-01-02 04:48:06 +01:00
committed by GitHub
parent 159fa31946
commit 1d8e4ec2ce
37 changed files with 1582 additions and 349 deletions

View File

@@ -1,10 +1,10 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_gemm_universal_impl.hpp"
#include "profiler_operation_registry.hpp"
@@ -27,6 +27,8 @@ enum struct GemmDataType
F16_F8_F16, // 5
F16_F16_F16_F8, // 6
F8_F8_BF16, // 7
F16_I4_F16, // 8
BF16_I4_BF16, // 9
};
#define OP_NAME "gemm_universal"
@@ -39,7 +41,7 @@ int profile_gemm_universal(int argc, char* argv[])
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8)\n");
"comp f8; 8: f16@i4; 9: bf16@i4\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
@@ -103,6 +105,7 @@ int profile_gemm_universal(int argc, char* argv[])
using BF16 = ck::bhalf_t;
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using F8 = ck::f8_t;
using I4 = ck::pk_i4_t;
#endif
using Row = ck::tensor_layout::gemm::RowMajor;
@@ -207,6 +210,14 @@ int profile_gemm_universal(int argc, char* argv[])
{
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F16{}, I4{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::BF16_I4_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(BF16{}, I4{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{});
}
#endif
else
{