Add support for mixed precision bf16&int8 grouped gemm (#1166)

* add support for mixed precision bf16&int8 grouped gemm

* fix gfx versions and add bf16 kbatch condition

* added reviewers comments
This commit is contained in:
jakpiase
2024-02-21 10:35:35 +01:00
committed by GitHub
parent 66736edb95
commit 32d4be3d09
10 changed files with 1159 additions and 19 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -650,22 +650,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
constexpr auto Set = InMemoryDataOperationEnum::Set;
if(arg.k_batch_ > 1)
{
if(has_main_k_block_loop)
{
ave_time =
launch_kernel(integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
else
{
ave_time =
launch_kernel(integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
}
else
// For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced
// in IsSupportedArgument function
if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
{
if(has_main_k_block_loop)
{
@@ -678,6 +665,39 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
integral_constant<InMemoryDataOperationEnum, Set>{});
}
}
else
{
if(arg.k_batch_ > 1)
{
if(has_main_k_block_loop)
{
ave_time = launch_kernel(
integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
else
{
ave_time = launch_kernel(
integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
}
else
{
if(has_main_k_block_loop)
{
ave_time =
launch_kernel(integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, Set>{});
}
else
{
ave_time =
launch_kernel(integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, Set>{});
}
}
}
return ave_time;
}
@@ -718,6 +738,13 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
}
}
// For bf16 datatype only kbatch = 1 is supported since there is no AtomicAdd
// instruction that supports bf16 and we cannot use splitk because of that
if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
{
supported = supported & (arg.k_batch_ == 1);
}
return supported;
}