mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Optimizing fp8_fp16 mixedprec gemm (#1150)
* add delayed cvt * extend fp16 gemm_splitk instances for fp8_fp16 gemm * add f8 example * add 128 kperblk instances for fp8 * add kpb128 instance * added more instances into kpb128 * clean code * clean code * fix * fix * fixed * Update example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> --------- Co-authored-by: Jing Zhang <jizha@amd.com> Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
This commit is contained in:
@@ -9,7 +9,6 @@
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
@@ -96,7 +95,10 @@ template <index_t BlockSize,
|
||||
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
typename ComputeType = FloatC>
|
||||
typename ComputeTypeA = FloatC,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ComputeTypeA,
|
||||
typename LDSTypeB = ComputeTypeB>
|
||||
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -430,7 +432,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
constexpr auto c_block_size =
|
||||
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
|
||||
|
||||
return math::max((a_block_space_size + b_block_space_size) * sizeof(ComputeType),
|
||||
return math::max(a_block_space_size * sizeof(LDSTypeA) +
|
||||
b_block_space_size * sizeof(LDSTypeB),
|
||||
c_block_size * sizeof(FloatC));
|
||||
}
|
||||
|
||||
@@ -785,7 +788,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatA,
|
||||
ComputeType,
|
||||
LDSTypeA,
|
||||
decltype(a_b_k0_m_k1_grid_desc),
|
||||
decltype(a_b_k0_m_k1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -815,7 +818,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatB,
|
||||
ComputeType,
|
||||
LDSTypeB,
|
||||
decltype(b_b_k0_n_k1_grid_desc),
|
||||
decltype(b_b_k0_n_k1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
@@ -845,8 +848,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
ComputeType, // ComputeType A
|
||||
ComputeType, // ComputeType B
|
||||
LDSTypeA,
|
||||
LDSTypeB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
@@ -855,7 +858,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
K1,
|
||||
LoopSched>();
|
||||
LoopSched,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>();
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
@@ -863,8 +868,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
ComputeType* p_a_block = static_cast<ComputeType*>(p_shared_block);
|
||||
ComputeType* p_b_block = static_cast<ComputeType*>(p_shared_block) + a_block_space_size;
|
||||
auto p_a_block = reinterpret_cast<LDSTypeA*>(p_shared_block);
|
||||
auto p_b_block = reinterpret_cast<LDSTypeB*>(p_a_block + a_block_space_size);
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
|
||||
|
||||
Reference in New Issue
Block a user