Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-05 06:01:23 +00:00)
Jing's contribution: prototype of mixed-precision FP16/BF16 x int4 GEMM (#1762)
* add a prototype of int4
* clean
* debug
* clean
* clean
* move packed into dynamic_buffer
* fixed coord reset
* add fast pki4 to half conversion
* fix
* fixed reference and host_tensor
* fixed tensor init
* format
* debug i4_to_f16_convert
* format
* fixed splitk
* weight permute
* add b tile permute
* clean
* weight permute with splitk
* format
* improve weight layout
* add and_or_b32
* fixed splitk crash
* add permute switch as a template
* recover v3r1
* clean
* failure with intrawave v2
* fixed
* fixed
* add ckProfiler
* add bfp16 support
* add bf16 example
* fixed int4 to bhalf_t conversion
* format
* fixed int4 to bf16 conversion
* clean
* add instances for mem
* clean
* fixed host tensor size
* fixed
* debug
* fixed
* add pk_i4_t as a struct
* fix
* Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
  Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
* Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
  Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
* Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
  Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
* revert
* Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
  Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
* Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
  Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
* Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
  Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
* Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
  Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
* Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
  Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
* fixed comments
* revert
* clean
* revert
* revert
* fixed
* Update CMakeLists.txt
* Update script/cmake-ck-dev.sh
  Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
* Update include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
  Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
* Update CMakeLists.txt
  Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
* fixed
* fixed
* fixed
* revert
* revert
* add comments
* format
* fixed assert
* fixed
* Fix I4 define in ckProfiler
* Fixed example_gemm_xdl_bf16_pk_i4_v3 test failure

---------

Co-authored-by: Jing Zhang <jizhan@fb.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: mtgu0705 <mtgu@amd.com>
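For orientation (this note is not part of the commit): pk_i4_t stores two signed 4-bit values in one byte, which is why sizes and offsets throughout this patch are divided by a PackedSize of 2. A minimal host-side sketch of the assumed packing convention (low nibble first; names here are illustrative, not CK API):

#include <cstdint>
#include <cstdio>

// Sign-extend a 4-bit nibble to int (range [-8, 7]).
static int sext4(unsigned v) { return (v & 0x8u) ? static_cast<int>(v) - 16 : static_cast<int>(v); }

int main()
{
    // Pack -3 and 5 into one byte, low nibble first (assumed convention).
    const uint8_t packed =
        static_cast<uint8_t>((static_cast<unsigned>(-3) & 0xFu) | ((5u & 0xFu) << 4));

    // Unpack both elements: one stored byte yields two logical values.
    std::printf("lo=%d hi=%d\n", sext4(packed & 0xFu), sext4(packed >> 4)); // lo=-3 hi=5
    return 0;
}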
@@ -127,7 +127,9 @@ template <typename ALayout,
           BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
           BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
           typename ComputeTypeA = CDataType,
-          typename ComputeTypeB = ComputeTypeA>
+          typename ComputeTypeB = ComputeTypeA,
+          bool PermuteA = false,
+          bool PermuteB = false>
 struct GridwiseGemm_xdl_cshuffle_v3
 {
     static constexpr auto I0 = Number<0>{};
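Because PermuteA and PermuteB are appended with defaults, every existing instantiation keeps compiling unchanged; only callers that opt in to the pre-shuffled weight layout pass true. A reduced sketch of the pattern (stand-in struct, not the real parameter list):

template <typename ComputeTypeA,
          typename ComputeTypeB = ComputeTypeA, // pre-existing defaulted parameter
          bool PermuteA = false,                // new, defaulted: old call sites unaffected
          bool PermuteB = false>
struct GridwiseGemmSketch
{
};

GridwiseGemmSketch<float> g_old;                     // existing code, still compiles
GridwiseGemmSketch<float, float, false, true> g_new; // opts in to permuted B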
@@ -151,6 +153,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
 
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
 
+    static constexpr index_t APackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+
+    static constexpr index_t BPackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+
     __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
     {
         return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
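The two additions above are immediately invoked constexpr lambdas; the same selection can be reproduced standalone, with a stand-in pk_i4_t (an assumed simplification of CK's type):

#include <cstdint>
#include <type_traits>

struct pk_i4_t { int8_t data; }; // stand-in for CK's packed-int4 type

template <typename T>
constexpr int PackedSizeOf = []() {
    if constexpr(std::is_same_v<std::remove_cv_t<T>, pk_i4_t>)
        return 2; // two 4-bit elements per stored byte
    else
        return 1; // one element per stored value
}();

static_assert(PackedSizeOf<pk_i4_t> == 2);
static_assert(PackedSizeOf<float> == 1);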
@@ -319,6 +335,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
 
         using GemmSpecialization = tensor_operation::device::GemmSpecialization;
 
+        static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
+                        GemmSpec != GemmSpecialization::Default),
+                      "pk_i4_t does not support padding");
+
         if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                      GemmSpec == GemmSpecialization::MNKPadding)
         {
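The guard exists because two pk_i4_t elements share one byte: a padded K dimension could start mid-byte, which cannot be addressed without unpacking. The arithmetic behind that reasoning, as a hypothetical one-liner:

// Element k of a packed-int4 row lives in byte k / 2, so elements 2i and
// 2i + 1 are inseparable; padding is therefore restricted to the Default spec.
constexpr int byte_of_element(int k) { return k / 2; }
static_assert(byte_of_element(6) == byte_of_element(7)); // 6 and 7 share one byte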
@@ -373,15 +393,39 @@ struct GridwiseGemm_xdl_cshuffle_v3
         }
         else
         {
-            // not pad N or K
-            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
-                b_grid_desc_nraw_kraw,
-                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
-                           make_pass_through_transform(N)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+            if constexpr(!PermuteB)
+            {
+                // not pad N or K
+                const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                    b_grid_desc_nraw_kraw,
+                    make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
+                               make_pass_through_transform(N)),
+                    make_tuple(Sequence<1>{}, Sequence<0>{}),
+                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
 
-            return b_grid_desc_bk0_n_bk1;
+                return b_grid_desc_bk0_n_bk1;
+            }
+            else
+            {
+                // Pre-shuffled Weight
+                // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
+                constexpr index_t BK01 = KPerBlock / BK1Value;
+                const index_t BK0_ = StrideB / BK1Value;
+                const index_t BK00 = BK0_ / BK01;
+
+                const auto b_grid_desc_bk00_n_bk01_bk1_permute =
+                    make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
+
+                const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
+                    b_grid_desc_bk00_n_bk01_bk1_permute,
+                    make_tuple(make_merge_transform(make_tuple(BK00, BK01)),
+                               make_pass_through_transform(make_tuple(N)),
+                               make_pass_through_transform(BK1Value)),
+                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+
+                return b_grid_desc_bk0_n_bk1_permute;
+            }
         }
     }
 
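A host-side sketch of the addressing the permuted descriptor implements, assuming the BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] layout named in the comment (function and values are illustrative, not CK API):

#include <cstdio>

// Map a logical (k, n) weight coordinate into the pre-shuffled layout
// BGlobal[K / KPerBlock, N, KPerBlock / K1, K1], stored packed row-major.
long preshuffled_offset(long k, long n, long N, long KPerBlock, long K1)
{
    const long k00 = k / KPerBlock;        // which K block
    const long k01 = (k % KPerBlock) / K1; // K1 group inside the block
    const long k1  = k % K1;               // position inside the K1 group
    return ((k00 * N + n) * (KPerBlock / K1) + k01) * K1 + k1;
}

int main()
{
    // e.g. KPerBlock = 64, K1 = 8, N = 16: element (k = 70, n = 3)
    std::printf("%ld\n", preshuffled_offset(70, 3, 16, 64, 8)); // prints 1222
    return 0;
}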
@@ -572,7 +616,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
     {
         if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
         {
-            a_k_split_offset = blockIdx.z * karg.KRead;
+            a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
         }
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
         {
@@ -585,7 +629,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
         }
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
         {
-            b_k_split_offset = blockIdx.z * karg.KRead;
+            if constexpr(!PermuteB)
+            {
+                b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
+            }
+            else
+            {
+                const int k0_offset = karg.KRead * karg.N;
+                b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
+            }
         }
 
         if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
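Worked numbers for the splitK offsets (all values assumed for illustration): each K-batch slice advances by karg.KRead elements, and with BDataType = pk_i4_t the element offset halves, because addresses count packed bytes:

#include <cstdio>

int main()
{
    const int KRead       = 128; // K elements each splitK slice reads (assumed)
    const int N           = 256; // columns of B (assumed)
    const int BPackedSize = 2;   // pk_i4_t: two elements per byte
    const int z           = 3;   // blockIdx.z of this slice

    const int plain   = z * KRead / BPackedSize;       // !PermuteB path: 192
    const int permute = z * (KRead * N) / BPackedSize; // PermuteB path: 49152

    std::printf("plain=%d permute=%d\n", plain, permute);
    return 0;
}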
@@ -625,9 +677,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
         // in some cases.
         else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
         {
-            constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1
-                                           ? 1
-                                           : 32 * 4 / KPerBlock / sizeof(ADataType);
+            constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
+            constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
             constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
                 make_tuple(
                     AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
@@ -761,10 +812,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
         else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
         {
             // NLdsLayer * K0 as logical Bank
-            constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1
-                                           ? 1
-                                           : 32 * 4 / KPerBlock / sizeof(BDataType);
-            ;
+            constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
+            constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
             constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
                 make_tuple(
                     BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
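Both LdsLayer computations (A and B side) follow the same shape: 32 banks x 4 bytes = 128 bytes per LDS bank row, divided by the byte footprint of one KPerBlock row, clamped to at least 1. A standalone check with assumed element sizes:

template <int KPerBlock, int ElemBytes, int PackedSize>
constexpr int lds_layer()
{
    constexpr int LdsSize = 32 * 4 / KPerBlock / ElemBytes / PackedSize;
    return LdsSize < 1 ? 1 : LdsSize;
}

static_assert(lds_layer<32, 2, 1>() == 2);  // fp16, KPerBlock = 32
static_assert(lds_layer<128, 2, 1>() == 1); // fp16, KPerBlock = 128 (clamped up)
static_assert(lds_layer<64, 1, 2>() == 1);  // pk_i4_t: 1 byte holding 2 elements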
@@ -946,8 +995,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
         constexpr auto c_block_size =
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
 
-        return math::max((a_block_space_size_aligned * sizeof(ADataType) +
-                          b_block_space_size_aligned * sizeof(BDataType)),
+        return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
+                          b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
                          c_block_size * sizeof(CShuffleDataType));
     }
 
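With the division in place, the A/B LDS budget is counted in actual stored bytes rather than in elements; for example (tile sizes assumed), a packed-int4 B tile now reserves half the bytes its element count would suggest:

constexpr long a_elems = 128 * 64; // aligned A tile elements (assumed)
constexpr long b_elems = 128 * 64; // aligned B tile elements (assumed)

constexpr long a_bytes = a_elems * 2 / 1; // sizeof(fp16) = 2, APackedSize = 1
constexpr long b_bytes = b_elems * 1 / 2; // sizeof(pk_i4_t) = 1, BPackedSize = 2

static_assert(a_bytes == 16384 && b_bytes == 4096); // B needs a quarter of A's bytes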
@@ -1312,8 +1361,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
             static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<BDataType*>(p_shared) +
-                a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
+            reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
+                                         a_block_space_size_aligned * sizeof(ADataType) /
+                                             APackedSize),
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
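The rewritten cast computes the B tile's start as a byte offset from the shared-memory base and only then reinterprets the pointer. A reduced sketch of the same idea (hypothetical helper, not CK API):

#include <cstddef>

// Place the B tile directly after the A tile inside one shared allocation.
// Offsetting through char* keeps the arithmetic in bytes, so it stays exact
// even when sizeof(BT) does not divide the A footprint evenly.
template <typename AT, typename BT>
BT* b_tile_start(void* p_shared, std::size_t a_elems, std::size_t a_packed_size)
{
    const std::size_t a_bytes = a_elems * sizeof(AT) / a_packed_size;
    return reinterpret_cast<BT*>(static_cast<char*>(p_shared) + a_bytes);
}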
@@ -1706,16 +1756,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
             static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
         auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<BDataType*>(p_shared_0) +
-                a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
+            bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
+                                 a_block_space_size_aligned * sizeof(ADataType)),
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
         auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
         auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<BDataType*>(p_shared_1) +
-                a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
+            bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
+                                 a_block_space_size_aligned * sizeof(ADataType)),
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
         auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
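The ping/pong buffer pair serves the double-buffered pipeline: one LDS tile is filled while the other is consumed, with roles swapping every K iteration. In outline (a toy host-side sketch, not the GPU pipeline):

#include <array>
#include <cstdio>

int main()
{
    std::array<int, 2> bufs{}; // stand-ins for the ping/pong LDS tiles
    bufs[0] = 0;               // prologue: tile 0 preloaded into ping

    for(int k = 0; k < 4; ++k)
    {
        bufs[(k + 1) % 2] = k + 1;                        // prefetch next tile
        std::printf("compute on tile %d\n", bufs[k % 2]); // consume current tile
    }
    return 0;
}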