Optimizing fp8_fp16 mixed-precision GEMM (#1150)

* add delayed cvt (type conversion)

* extend fp16 gemm_splitk instances for fp8_fp16 gemm

* add f8 example

* add KPerBlock=128 (kpb128) instances for fp8

* add kpb128 instance

* added more instances into kpb128

* clean code

* clean code

* fix

* fix

* fixed

* Update example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

---------

Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
This commit is contained in:
zjing14
2024-02-12 11:45:42 -06:00
committed by GitHub
parent 94fbaac002
commit 602c4cc0d9
10 changed files with 370 additions and 131 deletions

View File

@@ -60,7 +60,9 @@ template <typename ADataType,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename ComputeType = CDataType,
PipelineVersion PipelineVer = PipelineVersion::v1,
LoopScheduler LoopSched = make_default_loop_scheduler()>
LoopScheduler LoopSched = make_default_loop_scheduler(),
typename LDSTypeA = ComputeType,
typename LDSTypeB = ComputeType>
struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
BLayout,
@@ -81,6 +83,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
// TODO: should be exposed as Tparams.
static constexpr index_t NumGemmKPrefetchStage = 1;
using ComputeTypeA = ComputeType;
using ComputeTypeB = ComputeType;
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
ADataType,
@@ -125,7 +130,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopSched,
PipelineVer,
ComputeType>;
ComputeTypeA,
ComputeTypeB,
LDSTypeA,
LDSTypeB>;
struct Argument : public GridwiseGemm::Argument
{