diff --git a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp index c41be1e62a..6bda3cb511 100644 --- a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp @@ -39,6 +39,9 @@ using DeviceGemmV2Instance = 2, 32, 32, 1, 1, 1, S<1, 16, 1, 4>, 4, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; + + static int NPerBlock = 16; + static int KPerBlock = 256; #else 128, 16, 32, @@ -51,8 +54,11 @@ using DeviceGemmV2Instance = 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; + + static int NPerBlock = 32; + static int KPerBlock = 128; #endif - // clang-format on +// clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm) { -#if 0 +#ifndef WEIGHT_PERMUTE b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; #else - const int k0_offset = karg.KRead * NPerBlock; + const int k0_offset = karg.KRead * karg.N; b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; #endif }