mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
Jing's contribution: prototype of mixed precision gemm FP16/BF16xint4 GEMM (#1762)
* add a prototype of int4 * clean * debug * clean * clean * move packed into dynamic_buffer * fixed coord reset * add fast pki4 to half conversion * fix * fixed reference and host_tensor * fixed tensor init * format * debug i4_to_f16_convert * format * fixed splitk * weight permute * add b tile permute * clean * weight permute with splitki * format * improve weight layout * add and_or_b32 * fixed splitk crush * add permute switch as a template * recover v3r1 * clean * failure with intrawave v2 * fixed * fixed * add ckProfiler * add bfp16 support * add bf16 example * fixed int4 to bhalf_t conversion * format * fixed int4 to bf16 conversion * clean * add instances for mem * clean * fixed host tensor size * fixed * debug * fixed * add pk_i4_t as a struct * fix * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * revert * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * fixed comments * revert * clean * revert * revert * fixed * Update CMakeLists.txt * Update script/cmake-ck-dev.sh Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update CMakeLists.txt Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * fixed * fixed * fixed * revert * revert * add comments * format * fixed assert * fixed * Fix I4 define in ckProfiler * Fixed example_gemm_xdl_bf16_pk_i4_v3 test failed issue --------- Co-authored-by: Jing Zhang <jizhan@fb.com> Co-authored-by: zjing14 <zhangjing14@gmail.com> Co-authored-by: mtgu0705 <mtgu@amd.com>
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <initializer_list>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "profiler/profile_gemm_universal_impl.hpp"
|
||||
#include "profiler_operation_registry.hpp"
|
||||
@@ -27,6 +27,8 @@ enum struct GemmDataType
|
||||
F16_F8_F16, // 5
|
||||
F16_F16_F16_F8, // 6
|
||||
F8_F8_BF16, // 7
|
||||
F16_I4_F16, // 8
|
||||
BF16_I4_BF16, // 9
|
||||
};
|
||||
|
||||
#define OP_NAME "gemm_universal"
|
||||
@@ -39,7 +41,7 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
|
||||
"f16->f8; 7: f8->bf16, "
|
||||
"comp f8)\n");
|
||||
"comp f8; 8: f16@i4; 9: bf16@i4\n");
|
||||
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
|
||||
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
|
||||
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
|
||||
@@ -103,6 +105,7 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
using BF16 = ck::bhalf_t;
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
|
||||
using F8 = ck::f8_t;
|
||||
using I4 = ck::pk_i4_t;
|
||||
#endif
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
@@ -207,6 +210,14 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
{
|
||||
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(F16{}, I4{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::BF16_I4_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(BF16{}, I4{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user