mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Add fp16/fp8 support into Grouped gemm FixedNK (#874)
* move all arguments into device * add b2c_tile_map * add examples * add SetDeviceKernelArgs * dedicated fixed_nk solution * init client api * add grouped_gemm_bias example * add a instance * add instances * formatting * fixed cmake * Update EnableCompilerWarnings.cmake * Update cmake-ck-dev.sh * clean; fixed comments * fixed comment * add instances for fp32 output * add instances for fp32 output * add fp32 out client example * fixed CI * init commit for kbatch * add splitk gridwise * format * fixed * clean deviceop * clean code * finish splitk * fixed instances * change m_loops to tile_loops * add setkbatch * clean code * add splitK+bias * add instances * opt mk_nk instances * clean examples * fixed CI * remove zero * finished non-zero * clean * clean code * optimized global_barrier * fixed ci * fixed CI * instance and client * removed AddBias * format * fixed CI * fixed CI * move 20_grouped_gemm to 21_grouped_gemm * clean * formatting * clean * clean * fixed computeType --------- Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -29,7 +29,9 @@ namespace ck {
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename ABDataType, // FIXME: don't assume A/B have same datatype
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
@@ -96,17 +98,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
|
||||
// denorm test fix, required to work around fp16 mfma issue
|
||||
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
|
||||
// when mfma if fixed, remove this section and update
|
||||
// ABDataTypeAdjusted -> ABDataType throughout this file
|
||||
#if CK_WORKAROUND_DENORM_FIX
|
||||
using ABDataTypeAdjusted =
|
||||
conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
|
||||
#else
|
||||
using ABDataTypeAdjusted = ABDataType;
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
@@ -196,7 +187,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
|
||||
|
||||
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
|
||||
sizeof(ABDataType),
|
||||
sizeof(ComputeType),
|
||||
c_block_size * sizeof(CShuffleDataType));
|
||||
}
|
||||
|
||||
@@ -401,8 +392,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
// check tensor size: cannot be larger than 2GB each
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
|
||||
b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
|
||||
if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
|
||||
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
@@ -470,8 +461,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEElementwiseOperation_,
|
||||
typename Block2ETileMap>
|
||||
__device__ static void Run(const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
|
||||
const BDataType* __restrict__ p_b_grid,
|
||||
DsGridPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
void* __restrict__ p_shared,
|
||||
@@ -538,8 +529,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
Sequence<1, AK0PerBlock, MPerBlock, AK1>,
|
||||
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABDataType,
|
||||
ABDataTypeAdjusted,
|
||||
ADataType,
|
||||
ComputeType,
|
||||
decltype(a_grid_desc_kbatch_ak0_m_ak1),
|
||||
decltype(a_block_desc_kbatch_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -569,8 +560,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
Sequence<1, BK0PerBlock, NPerBlock, BK1>,
|
||||
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
ABDataType,
|
||||
ABDataTypeAdjusted,
|
||||
BDataType,
|
||||
ComputeType,
|
||||
decltype(b_grid_desc_kbatch_bk0_n_bk1),
|
||||
decltype(b_block_desc_kbatch_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
@@ -606,11 +597,11 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
// sanity check
|
||||
constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1, BK1),
|
||||
MfmaSelector<ABDataTypeAdjusted, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
ABDataTypeAdjusted,
|
||||
ComputeType,
|
||||
AccDataType,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
@@ -683,11 +674,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ABDataTypeAdjusted*>(p_shared),
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ABDataTypeAdjusted*>(p_shared) + a_block_space_size_aligned,
|
||||
static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
|
||||
@@ -999,8 +989,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
const index_t KBatch,
|
||||
const Block2ETileMap& block_2_etile_map)
|
||||
{
|
||||
const auto p_a_grid = reinterpret_cast<const ABDataType*>(p_a_grid_);
|
||||
const auto p_b_grid = reinterpret_cast<const ABDataType*>(p_b_grid_);
|
||||
const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
|
||||
const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
|
||||
const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
|
||||
|
||||
using DsGridDesc_M_N =
|
||||
|
||||
Reference in New Issue
Block a user