mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Add support for mixed precision bf16&int8 grouped gemm (#1166)
* add support for mixed precision bf16&int8 grouped gemm * fix gfx versions and add bf16 kbatch condition * added reviewers comments
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -650,22 +650,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
|
||||
constexpr auto Set = InMemoryDataOperationEnum::Set;
|
||||
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
ave_time =
|
||||
launch_kernel(integral_constant<bool, true>{},
|
||||
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
launch_kernel(integral_constant<bool, false>{},
|
||||
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
// For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced
|
||||
// in IsSupportedArgument function
|
||||
if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
|
||||
{
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
@@ -678,6 +665,39 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
integral_constant<InMemoryDataOperationEnum, Set>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
ave_time = launch_kernel(
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = launch_kernel(
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
ave_time =
|
||||
launch_kernel(integral_constant<bool, true>{},
|
||||
integral_constant<InMemoryDataOperationEnum, Set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
launch_kernel(integral_constant<bool, false>{},
|
||||
integral_constant<InMemoryDataOperationEnum, Set>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
@@ -718,6 +738,13 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
}
|
||||
}
|
||||
|
||||
// For bf16 datatype only kbatch = 1 is supported since there is no AtomicAdd
|
||||
// instruction that supports bf16 and we cannot use splitk because of that
|
||||
if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
|
||||
{
|
||||
supported = supported & (arg.k_batch_ == 1);
|
||||
}
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user