From 9fed0adea86022a8ff1dfb437f6bda7335eff323 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Tue, 22 Oct 2024 14:29:01 -0700 Subject: [PATCH] weight permute with splitki --- example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp | 35 +++++++++++++------ .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 10 +++--- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp index c41be1e62a..6bda3cb511 100644 --- a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp @@ -39,6 +39,9 @@ using DeviceGemmV2Instance = 2, 32, 32, 1, 1, 1, S<1, 16, 1, 4>, 4, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; + + static int NPerBlock = 16; + static int KPerBlock = 256; #else 128, 16, 32, @@ -51,8 +54,11 @@ using DeviceGemmV2Instance = 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; + + static int NPerBlock = 32; + static int KPerBlock = 128; #endif - // clang-format on +// clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm) { -#if 0 +#ifndef WEIGHT_PERMUTE b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; #else - const int k0_offset = karg.KRead * NPerBlock; + const int k0_offset = karg.KRead * karg.N; b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; #endif }