diff --git a/CMakeLists.txt b/CMakeLists.txt index 37962c14e3..8f31267b64 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -246,13 +246,6 @@ if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500500000) add_compile_options("SHELL: -mllvm --lsr-drop-solution=1") endif() endif() -if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090) - check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED) - if(HAS_ENABLE_POST_MISCHED) - message("Adding the enable-post-misched=0 compiler flag") - add_compile_options("SHELL: -mllvm -enable-post-misched=0") - endif() -endif() set(check-coerce) check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce) if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132) @@ -534,7 +527,6 @@ include_directories(BEFORE ${HIP_INCLUDE_DIRS} ) -SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") diff --git a/Jenkinsfile b/Jenkinsfile index a3a637666f..80392bfbed 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -722,6 +722,9 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM pipeline { agent none + triggers { + parameterizedCron(CRON_SETTINGS) + } options { parallelsAlwaysFailFast() } diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp index 9b7849a654..b54ba5ddfb 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp @@ -55,7 +55,7 @@ using CDEElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr ck::index_t Scale_Block_M = 128; +static constexpr ck::index_t Scale_Block_M = 1; static constexpr ck::index_t Scale_Block_N = 128; static constexpr ck::index_t Scale_Block_K = 128; @@ -65,14 +65,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_ A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, - 128, 128, - 128, 16, 16, + 16, 128, + 256, 16, 16, 16, 16, - 4, 4, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; + 1, 2, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 1, 2, S<1, 16, 1, 16>, S<8>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; // clang-format on int main(int argc, char* argv[]) @@ -80,11 +80,12 @@ int main(int argc, char* argv[]) bool do_verification = true; int init_method = 1; bool time_kernel = false; + bool flush_cache = true; // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; + ck::index_t M = 128; + ck::index_t N = 1024; + ck::index_t K = 1024; ck::index_t StrideA = K; ck::index_t StrideB = K; @@ -100,7 +101,7 @@ int main(int argc, char* argv[]) init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); } - else if(argc == 10) + else if(argc == 8) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); @@ -110,16 +111,19 @@ int main(int argc, char* argv[]) N = std::stoi(argv[5]); K = std::stoi(argv[6]); - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideE = std::stoi(argv[9]); + flush_cache = std::stoi(argv[7]); + + StrideA = K; + StrideB = K; + StrideE = N; } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"); + printf("arg4 to 6: M, N, K\n"); + printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n"); exit(0); } @@ -182,9 +186,15 @@ int main(int argc, char* argv[]) b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); break; case 4: - a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); break; default: @@ -194,6 +204,16 @@ int main(int argc, char* argv[]) b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); } #endif +#if 0 + for(int im =0; im< (M + Scale_Block_M - 1) / Scale_Block_M; im++){ + float row_sum = .0; + for(int ik =0; ik< (K + Scale_Block_K - 1) / Scale_Block_K; ik++){ + printf("%lf ",a1_m_k(im, ik)); + row_sum += a1_m_k(im, ik); + } + printf("sum: %lf\n", row_sum * 128); + } +#endif DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize()); @@ -239,12 +259,24 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); - std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + float ave_time = .0; + + if(flush_cache) + { + int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype; + + ave_time = invoker.Run(argument, + StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf}); + } + else + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100}); + } + float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 4c23250d05..6326a97f8e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -176,7 +176,8 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) ); }} -float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ +template <> +float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ float r = -1; {F_dispatch} return r; @@ -412,14 +413,26 @@ class FmhaBwdDQDKDVKernel: pn = pad_name() n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}' if pn != '' : n += f'_{pn}' + else: n += '_npad' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + if self.F_dbias == 't' : n += '_dbias' + else: n += '_ndbias' + if self.F_mask[0:2] == 's_': if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + if self.F_dropout != 'no' : n += f'_{self.F_dropout}' + else: n += '_ndropout' + if self.F_deterministic == 't' : n += '_deterministic' + else: n += '_ndeterministic' return n @property @@ -489,7 +502,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic) - if kernel_filter != None: + if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue # Flash attention integration @@ -517,23 +530,19 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if not cond: continue # Aiter (mha_bwd) integration - elif receipt == 10: + elif receipt == 300: cond = dtype in ['fp16', 'bf16'] cond &= mode == "batch" - cond &= bias in ['no', 'alibi'] cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] cond &= dpad == dvpad - cond &= deterministic == "t" if not cond: continue # Aiter (mha_varlen_bwd) integration - elif receipt == 11: + elif receipt == 400: cond = dtype in ['fp16', 'bf16'] cond &= mode == "group" - cond &= bias in ['no', 'alibi'] cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] cond &= dpad == dvpad - cond &= deterministic == "t" if not cond: continue api_pool.register_dq_dk_dv_traits(k.api_trait()) @@ -632,13 +641,14 @@ class FmhaBwdOGradDotOKernel: pn = pad_name() n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}" if pn != '' : n += f'_{pn}' + else: n += '_npad' return n @property def filename(self) -> str: return self.name + ".cpp" -def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: +def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdOGradDotOKernel]: # TODO: we don't support tuning yet, so pick up one value for pad/occupancy # support this in future def get_occupancy(dtype, hdim): @@ -657,6 +667,21 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: k = FmhaBwdOGradDotOKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_spad=spad, F_dvpad=dvpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim)) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + # Aiter (mha_bwd) integration + if receipt == 300: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + if not cond: + continue + # Aiter (mha_varlen_bwd) integration + elif receipt == 400: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + if not cond: + continue gen.append(k) return gen @@ -766,14 +791,16 @@ class FmhaBwdConvertQGradKernel: pn = pad_name() n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}" if pn != '' : n += f'_{pn}' - if self.F_deterministic == 't' : n += f'_deterministic' + else: n += '_npad' + if self.F_deterministic == 't' : n += '_deterministic' + else: n += '_ndeterministic' return n @property def filename(self) -> str: return self.name + ".cpp" -def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]: +def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdConvertQGradKernel]: # TODO: we don't support tuning yet, so pick up one value for pad/occupancy # support this in future def get_occupancy(dtype, hdim): @@ -792,6 +819,21 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]: continue k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0, F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + # Aiter (mha_bwd) integration + if receipt == 300: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + if not cond: + continue + # Aiter (mha_varlen_bwd) integration + elif receipt == 400: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + if not cond: + continue gen.append(k) return gen @@ -808,27 +850,33 @@ def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autoge def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: - kernels = get_bwd_dot_do_o_blobs() +def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None: + filter_list = filter_list.split('@') + filter_list.extend([''] * (3 - len(filter_list))) + + kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) for kernel in kernels: write_single_bwd_dot_do_o_kernel(kernel, output_dir) - kernels = get_bwd_convert_dq_blobs() + kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt) for kernel in kernels: write_single_bwd_convert_dq_kernel(kernel, output_dir) - api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + api_pool, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl) for kernel in kernels: write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) write_bwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None: + filter_list = filter_list.split('@') + filter_list.extend([''] * (3 - len(filter_list))) + with file_path.open('a') as f: - kernels = get_bwd_dot_do_o_blobs() + kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - kernels = get_bwd_convert_dq_blobs() + kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index b72627ed5d..f2d9216696 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -233,14 +233,26 @@ class FmhaFwdPipeline: pn = pad_name() n = f'{self.tag}_v{self.F_vlayout[0]}' if pn != '' : n += f'_{pn}' + else: n += '_npad' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + if self.F_mask[0:2] == 's_': if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + if self.F_dropout == 't' : n += '_dropout' + else: n += '_ndropout' + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' return n class FmhaFwdApiPool: @@ -484,7 +496,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm F_tile=tile, F_pipeline=pipeline, mask_impl=mask_impl) - if kernel_filter != None: + if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue # 2 - Flash attention integration @@ -504,20 +516,18 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm if not cond: continue # Aiter(mha_fwd) integration - elif receipt == 10: + elif receipt == 100: cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" + cond &= mode == 'batch' cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] cond &= pipeline.F_squant == 'f' if not cond: continue # Aiter(mha_varlen_fwd) integration - elif receipt == 11: + elif receipt == 200: cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" + cond &= mode == 'group' cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] cond &= pipeline.F_squant == 'f' if not cond: continue @@ -532,13 +542,13 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def write_blobs(output_dir : Path, kernel_filter : str, receipt, mask_impl) -> None: api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(file_path : Path, kernel_filter : str, receipt, mask_impl) -> None: with file_path.open('a') as f: _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index f8a89448ba..16048e3fb6 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -323,12 +323,11 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> F_tile=tile, F_pipeline=pipeline, mask_impl=mask_impl) - if kernel_filter != None: + if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue # 2 - Flash attention integration - # 12 - Aiter(mha_fwd_kvcache) integration - if receipt in (2, 12): + if receipt == 2: cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' if not cond: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index c0ca666b11..ba555df88d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -397,14 +397,26 @@ class FmhaFwdSplitKVPipeline: pn = pad_name() n = f'{self.tag}_v{self.F_vlayout[0]}' if pn != '' : n += f'_{pn}' + else: n += '_npad' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + if self.F_mask[0:2] == 's_': if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' + if self.F_pagedkv == 't' : n += '_pagedkv' + else: n += '_npagedkv' return n @dataclass @@ -702,7 +714,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> F_tile=tile, F_pipeline=pipeline, mask_impl=mask_impl) - if kernel_filter != None: + if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue # Flash attention integration @@ -714,20 +726,10 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if not cond: continue # Aiter(mha_varlen_fwd) integration - elif receipt == 11: + elif receipt == 200: cond = dtype in ['fp16', 'bf16'] cond &= mode == "group" cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] - cond &= pipeline.F_squant == 'f' - if not cond: - continue - # Aiter(mha_fwd_kvcache) integration - elif receipt == 12: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] cond &= pipeline.F_squant == 'f' if not cond: continue @@ -780,9 +782,15 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis F_mode=mode, F_tile=tile, F_pipeline=pipeline) - if kernel_filter != None: + if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + # Aiter(mha_varlen_fwd) integration + if receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + if not cond: + continue gen.append(k) return gen @@ -794,21 +802,27 @@ def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) - file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME file_path.write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: - kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt) +def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None: + filter_list = filter_list.split('@') + filter_list.extend([''] * (2 - len(filter_list))) + + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) for kernel in kernels: write_single_kernel(kernel, output_dir) - api_pool, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl) + api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_splitkv_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None: + filter_list = filter_list.split('@') + filter_list.extend([''] * (2 - len(filter_list))) + with file_path.open('a') as f: - kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt) + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 6204cbcfa8..9179dbd9be 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -452,4 +452,5 @@ struct fmha_bwd_traits bool is_deterministic; // TODO: padding check is inside this api }; +template float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 0c2cef1ce7..0d35db14d4 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -30,7 +30,7 @@ handlers = dict( ) assert 0 < len(handlers) -def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: +def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None: if output_dir is None: output_dir = Path(__file__).parent else: @@ -38,19 +38,19 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : output_dir.mkdir(parents=True, exist_ok=True) - for api in api_list: + for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.WRITE_BLOBS] handler(output_dir, kernel_filter, receipt, mask_impl) # list all the files that will be generated -def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None: assert output_file is not None file_path = Path(output_file) # create an empty file / drop its contents if it exists open(file_path, "w").close() - for api in api_list: + for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.LIST_BLOBS] handler(file_path, kernel_filter, receipt, mask_impl) @@ -84,6 +84,7 @@ if __name__ == "__main__": parser.add_argument( "-f", "--filter", + default='', required=False, help="filter out kernels that need to generate, using fnmatch module" ) @@ -105,15 +106,19 @@ if __name__ == "__main__": " 1: generate more instance to cover all hdim\n" + \ " 2: Only generate instance for Flash attention integration\n" + \ " 4: Only generate instance for PyTorch integration\n" + \ - " 10: Only generate instance for Aiter(mha_fwd, mha_bwd) integration\n" + \ - " 11: Only generate instance for Aiter(mha_varlen_fwd, mha_varlen_bwd) integration\n" + \ - " 12: Only generate instance for Aiter(mha_fwd_kvcache) integration" - + " 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \ + " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \ + " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \ + " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration" + ) args = parser.parse_args() api_list = args.direction.split(',') + filter_list = args.filter.split(',') + filter_list.extend([''] * (len(api_list) - len(filter_list))) + if args.list_blobs is not None: - list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask) + list_blobs(args.list_blobs, api_list, filter_list, int(args.receipt), mask_impl=args.mask) else: - write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask) + write_blobs(args.output_dir, api_list, filter_list, int(args.receipt), mask_impl=args.mask) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 5dc7b9cd0b..57298b68dc 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -10,7 +10,7 @@ #include #include "ck_tile/host.hpp" -#include "gemm_basic.hpp" +#include "gemm_utils.hpp" template -struct GemmBasicTypeConfig; +struct GemmTypeConfig; template <> -struct GemmBasicTypeConfig +struct GemmTypeConfig { using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; @@ -49,7 +114,7 @@ struct GemmBasicTypeConfig }; template <> -struct GemmBasicTypeConfig +struct GemmTypeConfig { using ADataType = ck_tile::bf16_t; using BDataType = ck_tile::bf16_t; @@ -58,7 +123,7 @@ struct GemmBasicTypeConfig }; template <> -struct GemmBasicTypeConfig +struct GemmTypeConfig { using ADataType = ck_tile::fp8_t; using BDataType = ck_tile::fp8_t; @@ -67,7 +132,7 @@ struct GemmBasicTypeConfig }; template <> -struct GemmBasicTypeConfig +struct GemmTypeConfig { using ADataType = ck_tile::bf8_t; using BDataType = ck_tile::bf8_t; @@ -76,7 +141,7 @@ struct GemmBasicTypeConfig }; template <> -struct GemmBasicTypeConfig +struct GemmTypeConfig { using ADataType = ck_tile::half_t; using BDataType = ck_tile::pk_int4_t; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index f068cbc1da..6cb40e45d1 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -29,8 +29,67 @@ auto calculate_rtol_atol(const ck_tile::index_t K, // Use higher threshold return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template + +template void permute_tensor_b(Tensor& tensor) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = GEMM_PIPELINE; + + const ck_tile::index_t K = tensor.get_length(0); + const ck_tile::index_t N = tensor.get_length(1); + const ck_tile::index_t K1 = GemmPipeline::GetSmemPackB(); + const ck_tile::index_t K0 = K / K1; + + Tensor tensor_copy = tensor; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + tensor(j * N * K1 + i * K1 + jj) = tensor_copy(i * K + (j * K1 + jj)); + } + } + } +} + +template +void permute_vectors_i4x4_b(Tensor& tensor) { const ck_tile::index_t K = tensor.get_length(0); const ck_tile::index_t N = tensor.get_length(1); @@ -153,7 +212,7 @@ int run_gemm_example_with_layouts(int argc, if(!result) return -1; - using AccDataType = typename GemmBasicTypeConfig::AccDataType; + using AccDataType = typename GemmTypeConfig::AccDataType; ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); @@ -181,8 +240,8 @@ int run_gemm_example_with_layouts(int argc, if(init_method == 0) { - ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); - ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); } else if(init_method == 1) { @@ -204,18 +263,36 @@ int run_gemm_example_with_layouts(int argc, ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); - a_m_k_dev_buf.ToDevice(a_m_k.data()); + static_assert(!GemmConfig::PermuteA, "Not implemented"); if constexpr(std::is_same_v) { - // Permute data for device implementation + // Permute vector pk_i4x4 data for device implementation ck_tile::HostTensor b_k_n_dev = b_k_n; - permute_tensor_b(b_k_n_dev); + if constexpr(GemmConfig::PermuteB) + { + permute_tensor_b(b_k_n_dev); + } + permute_vectors_i4x4_b(b_k_n_dev); b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); } else { + if constexpr(GemmConfig::PermuteB) + { + std::cout << "Permute for this DataType is not implemented." << std::endl; + return false; + } b_k_n_dev_buf.ToDevice(b_k_n.data()); } + + a_m_k_dev_buf.ToDevice(a_m_k.data()); c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index ab763437e5..8c04066b20 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -10,7 +10,7 @@ #include #include "ck_tile/host.hpp" -#include "gemm_basic.hpp" +#include "gemm_utils.hpp" template float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - // Memory friendly for Interwave scheduler - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 32; - constexpr ck_tile::index_t K_Tile = 64; + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; - constexpr ck_tile::index_t M_Warp = 4; - constexpr ck_tile::index_t N_Warp = 1; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; - - constexpr bool DoubleSmemBuffer = false; -#endif -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) - // Compute friendly for Intrawave scheduler - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 64; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; - - constexpr bool DoubleSmemBuffer = false; -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) - // Compute friendly for Intrawave scheduler - // Using the ping pong reader in the lds level - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 32; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; - - constexpr bool DoubleSmemBuffer = true; -#endif - - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - - constexpr bool TransposeC = false; - - constexpr int kBlockPerCu = 1; - constexpr ck_tile::index_t TileParitionerGroupNum = 8; - constexpr ck_tile::index_t TileParitionerM01 = 4; - - // =============================================== - - using GemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - using TilePartitioner = ck_tile:: - GemmSpatiallyLocalTilePartitioner; - - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + GemmConfig::TransposeC>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); @@ -133,11 +82,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, UniversalGemmProblem::TransposeC>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -158,8 +107,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& << std::endl; } - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/15_fused_moe/fused_moe.hpp b/example/ck_tile/15_fused_moe/fused_moe.hpp index 1f2246fa4a..b354d1d347 100644 --- a/example/ck_tile/15_fused_moe/fused_moe.hpp +++ b/example/ck_tile/15_fused_moe/fused_moe.hpp @@ -17,6 +17,9 @@ struct fused_moe_args const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP void* o_ptr; // [m, k], output token (no need to do zeroing) + void* ws_ptr; // size is moe_sorting_get_workspace_size() + // if return zero, then could be nullptr + // must be cleard before use const void* topk_ids_ptr; // [tokens, topk] const void* topk_weight_ptr; // [tokens, topk] diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp index cf9ff2edba..466420f066 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -27,6 +27,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; a.o_ptr, // void* p_moe_buf; + a.ws_ptr, // void* p_ws; a.num_tokens, // index_t tokens; a.block_m, // index_t unit_size; a.num_experts, // index_t num_experts; diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index 95adcd684b..cb93ce8907 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -371,6 +371,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem num_sorted_tiles_buf( num_sorted_tiles_host.get_element_space_size_in_bytes()); + // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr + ck_tile::index_t workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts); + ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0); + if(workspace_size != 0) + moe_sorting_ws.SetZero(); // note, clear here!!!! + fused_moe_traits traits{prec_i, prec_w, prec_o, @@ -394,6 +400,7 @@ bool run(const ck_tile::ArgParser& arg_parser) local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer() : nullptr, o_buf.GetDeviceBuffer(), + workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, topk_ids_buf.GetDeviceBuffer(), topk_weight_buf.GetDeviceBuffer(), sorted_token_ids_buf.GetDeviceBuffer(), diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 2ffef95196..14d450034d 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -10,10 +10,10 @@ #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" template -struct GemmBasicTypeConfig; +struct GemmTypeConfig; template <> -struct GemmBasicTypeConfig +struct GemmTypeConfig { using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; @@ -21,7 +21,7 @@ struct GemmBasicTypeConfig using AccDataType = float; }; -using Types = GemmBasicTypeConfig; +using Types = GemmTypeConfig; // Specific type aliases for easy access using ADataType = Types::ADataType; diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index e04e27b761..402d924cbd 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -5,11 +5,17 @@ #ifndef __HIPCC_RTC__ #include -#include +#include #include namespace ck { +constexpr unsigned int fnv1a_hash(std::string_view str, unsigned int h = 2166136261u) +{ + return str.empty() ? h + : fnv1a_hash(str.substr(1), + (h ^ static_cast(str.front())) * 16777619u); +} inline std::string get_device_name() { hipDeviceProp_t props{}; @@ -19,37 +25,31 @@ inline std::string get_device_name() { return std::string(); } - status = hipGetDeviceProperties(&props, device); if(status != hipSuccess) { return std::string(); } const std::string raw_name(props.gcnArchName); - - // https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 - static std::map device_name_map = { - {"Ellesmere", "gfx803"}, - {"Baffin", "gfx803"}, - {"RacerX", "gfx803"}, - {"Polaris10", "gfx803"}, - {"Polaris11", "gfx803"}, - {"Tonga", "gfx803"}, - {"Fiji", "gfx803"}, - {"gfx800", "gfx803"}, - {"gfx802", "gfx803"}, - {"gfx804", "gfx803"}, - {"Vega10", "gfx900"}, - {"gfx901", "gfx900"}, - {"10.3.0 Sienna_Cichlid 18", "gfx1030"}, - }; - const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str. - - auto match = device_name_map.find(name); - if(match != device_name_map.end()) - return match->second; - return name; + switch(fnv1a_hash(name)) + { + // https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 + case fnv1a_hash("Ellesmere"): + case fnv1a_hash("Baffin"): + case fnv1a_hash("RacerX"): + case fnv1a_hash("Polaris10"): + case fnv1a_hash("Polaris11"): + case fnv1a_hash("Tonga"): + case fnv1a_hash("Fiji"): + case fnv1a_hash("gfx800"): + case fnv1a_hash("gfx802"): + case fnv1a_hash("gfx804"): return "gfx803"; + case fnv1a_hash("Vega10"): + case fnv1a_hash("gfx901"): return "gfx900"; + case fnv1a_hash("10.3.0 Sienna_Cichlid 18"): return "gfx1030"; + default: return name; + } } inline bool is_xdl_supported() diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp index a2a552b521..526c4216fa 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp index 8ee80a1d76..9781caf28c 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp @@ -306,9 +306,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1(a_thread_buf[I0]), - // type_convert(b_thread_bufs[mfma_reg_buf][I0])); static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, KRepeat, 1>{}([&](auto k0) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp index 821bbb0051..8375e81fa0 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp @@ -7,10 +7,10 @@ namespace ck { -// Naive pipeline with lowest resource request per WGP -// GlobalPrefetchStages: 1 +// Compute optimized pipeline +// GlobalPrefetchStages: 2 // LocalPreFillStages: 1 -// LocalPreFetchStages: 0 +// LocalPreFetchStages: 1 // LocalSharedMemoryBuffer: 1 template + KPack, + true> { using Base = BlockwiseGemmXdlops_pipeline_base; + KPack, + true>; + using Base::A_K1; + using Base::B_K1; using Base::I0; + using Base::I1; using Base::KRepeat; using Base::xdlops_gemm; + using typename Base::HotLoopInstList; using Base::CalculateCThreadOriginDataIndex; using Base::CalculateCThreadOriginDataIndex8D; @@ -131,19 +137,43 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale PrefetchStages; @@ -151,11 +181,116 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale + // sizeof(ComputeDataType) / sizeof(BDataType) + // ? sizeof(ComputeDataType) / sizeof(ADataType) + // : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); } template ( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -223,6 +359,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale( b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); // Global prefetch 1 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); @@ -231,11 +369,26 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, @@ -243,17 +396,101 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}); + constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{}); + constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); + + static_for<0, num_scale_m_block, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + // Local prefill 1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + // Initialize C c_thread_buf.Clear(); - auto c_thread_buf_per_scale = remove_cvref_t(); + StaticBufferTupleOfVector + c_thread_buf_per_scale; + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); // main body if constexpr(HasMainLoop) @@ -261,13 +498,85 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = + CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert( + c_scale_thread_buf[Number{}]); + }); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k) { static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -289,19 +598,70 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); - static_for<0, KRepeat, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + i += 1; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { vector_type a_thread_vec; vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; + make_tuple(m0, + I0, + kscale0 * KRepeat / num_scale_k_block + k0, + ik))>{}]; b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; + make_tuple(n0, + I0, + kscale0 * KRepeat / num_scale_k_block + k0, + ik))>{}]; }); using mfma_input_type = @@ -311,46 +671,41 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale( a_thread_vec.template AsType(), b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); }); static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[I0]) * - type_convert(b_scale_thread_buf[I0]); + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert( + c_scale_thread_buf[Number{}]); }); }); }); + }); - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc, - make_tuple(I0, I0), - a_scale_thread_buf); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc, - make_tuple(I0, I0), - b_scale_thread_buf); + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); - a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step); - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); - - block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - - i += 1; - - } while(i < (num_loop - 1)); - } - - // tail - if constexpr(TailNum == TailNumber::Full) - { block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k) { static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -371,49 +726,143 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - using mfma_input_type = - typename vector_type::type; + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - }); - static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[I0]) * - type_convert(b_scale_thread_buf[I0]); + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert( + c_scale_thread_buf[Number{}]); + }); }); }); }); + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert( + c_scale_thread_buf[Number{}]); + }); + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); } } protected: - using Base::a_thread_copy_; using Base::a_thread_desc_; - using Base::b_thread_copy_; using Base::b_thread_desc_; using Base::c_thread_desc_; + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp index 40fa776484..c8ad9c5b02 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp @@ -96,7 +96,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale + KPack, + true> { using Base = BlockwiseGemmXdlops_pipeline_base; + KPack, + true>; using Base::I0; using Base::KRepeat; using Base::xdlops_gemm; @@ -270,11 +272,26 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if(num_loop_per_scale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, @@ -282,7 +299,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}) += c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[I0]) * + type_convert(a_scale_thread_buf[m0]) * type_convert(b_scale_thread_buf[I0]); }); }); }); - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc, - make_tuple(I0, I0), - a_scale_thread_buf); + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + }); + + if(num_loop_per_scale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, @@ -378,8 +409,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}) += c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[I0]) * + type_convert(a_scale_thread_buf[m0]) * type_convert(b_scale_thread_buf[I0]); }); }); }); - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc, - make_tuple(I0, I0), - a_scale_thread_buf); + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + }); + + if(num_loop_per_scale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, @@ -471,7 +515,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}) += c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[I0]) * + type_convert(a_scale_thread_buf[m0]) * type_convert(b_scale_thread_buf[I0]); }); }); @@ -586,7 +629,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}) += c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[I0]) * + type_convert(a_scale_thread_buf[m0]) * type_convert(b_scale_thread_buf[I0]); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp index de542866a6..fc0075b196 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp @@ -96,7 +96,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale + KPack, + true> { using Base = BlockwiseGemmXdlops_pipeline_base; + KPack, + true>; using Base::I0; using Base::KRepeat; using Base::xdlops_gemm; @@ -177,11 +179,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}) == 1, + "Pipeline v3 only support scaleblocksliceK=1"); + static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1, + "Pipeline v3 only support scaleblocksliceN=1"); // assume kperblock = scaleblockk - ignore = num_loop_per_scale; auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -330,6 +337,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale( b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); // Global prefetch 1 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); @@ -338,11 +347,26 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, @@ -350,8 +374,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0]; + }); + // Local prefill 1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -363,10 +391,44 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + // Initialize C c_thread_buf.Clear(); - auto c_thread_buf_per_scale = remove_cvref_t(); + StaticBufferTupleOfVector + c_thread_buf_per_scale; // Local prefetch 1 block_sync_lds(); @@ -409,7 +471,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); static_for<0, KRepeat, 1>{}([&](auto k0) { vector_type a_thread_vec; vector_type b_thread_vec; @@ -430,19 +495,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale( a_thread_vec.template AsType(), b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); }); static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[I0]) * - type_convert(b_scale_thread_buf[I0]); + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[m0]); }); }); }); + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0]; + }); + block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k) { static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -462,11 +531,27 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, @@ -474,7 +559,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - c_thread_buf_per_scale.Clear(); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); static_for<0, KRepeat, 1>{}([&](auto k0) { vector_type a_thread_vec; vector_type b_thread_vec; @@ -507,15 +594,15 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale( a_thread_vec.template AsType(), b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); }); static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert(a_scale_thread_buf[I0]) * - type_convert(b_scale_thread_buf[I0]); + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[m0]); }); }); }); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp index 480402b7e1..d5fec7201a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp @@ -15,6 +15,7 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" namespace ck { namespace tensor_operation { @@ -177,14 +178,57 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); const auto Run = [&](const auto& kernel) { - if(arg.KBatch > 1) - hipGetErrorString(hipMemsetAsync(arg.p_c_grid, - 0, - arg.M * arg.N * sizeof(CDataType), - stream_config.stream_id_)); + if(stream_config.flush_cache) + { + Argument arg_ = arg; - ave_time = launch_and_time_kernel( - stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } }; constexpr index_t minimum_occupancy = @@ -195,7 +239,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 if(has_main_k_block_loop) { - // Tail number always 1 + // Tail number always full if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { @@ -208,127 +252,13 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 Run(kernel); } } - // Tail number could be One to Seven - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) - { - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Full) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Three) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Four) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Five) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Seven) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - } - } } else { // Tail number always 1 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) { const auto kernel = kernel_gemm_xdl_cshuffle_v3; Run(kernel); } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } } } return ave_time; @@ -363,10 +303,11 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 return false; } - if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK != KPerBlock) - { - return false; - } + // if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK != + // KPerBlock) + // { + // return false; + // } if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::NKPadding || diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index 813acfa656..25be9bebb7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -225,7 +225,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); } - __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) { const auto a_grid_desc_mraw_kraw = [&]() { @@ -307,7 +307,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 } } - __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) { const auto b_grid_desc_nraw_kraw = [&]() { @@ -422,6 +422,13 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 } }(); + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 using GemmSpecialization = tensor_operation::device::GemmSpecialization; if constexpr(GemmSpec == GemmSpecialization::MNPadding || @@ -459,6 +466,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 // not pad M or N return c_grid_desc_mraw_nraw; } +#endif } __host__ __device__ static auto MakeDsGridDescriptor_M_N( @@ -656,40 +664,19 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 // in some cases. else if constexpr(is_same::value) { - constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1 - ? 1 - : 32 * 4 / KPerBlock / sizeof(LDSTypeA); - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - AK0Number * Number{}, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), make_pass_through_transform(AK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_ak0_mldslayer_m_ak1, - make_tuple(make_pass_through_transform(AK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; + return a_lds_block_desc_permuted; } else // ColumnMajor A { @@ -791,42 +778,19 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 } else if constexpr(is_same::value) { - // NLdsLayer * K0 as logical Bank - constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1 - ? 1 - : 32 * 4 / KPerBlock / sizeof(LDSTypeB); - ; - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - BK0Number * Number{}, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); + constexpr auto b_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), make_pass_through_transform(BK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_pass_through_transform(BK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; + return b_lds_block_desc_permuted; } else // RowMajor B { @@ -992,7 +956,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) { if(!(karg.M % MPerBlock == 0)) { @@ -1009,7 +974,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) { if(!(karg.N % NPerBlock == 0)) { @@ -1357,28 +1323,39 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - const index_t ScaleSliceSizeM = 1; - const index_t ScaleSliceSizeN = 1; - const index_t ScaleSliceSizeK = 1; + constexpr index_t ScaleSliceSizeM = MXdlPerWave; + constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN); + constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); + // ScaleSliceSizeK is last dimension in A/B scale for vector memory access + // ScaleSliceSizeK is first dimension in C scale for packed math constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + auto a_thread_offset = + get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); + make_tuple(Number{}, Number{})); + + constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{})); auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2, + Sequence<1, ScaleSliceSizeK>, Sequence<0, 1>, 1, - 1, + ScaleSliceSizeK, 1, false>( - a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM, 0)); + a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0)); auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2, Sequence<0, 1>, 1, - 1, + ScaleSliceSizeK, 1, false>( b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); - constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); - constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, 1); + // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); + constexpr auto a_scale_thread_slice_copy_step = + make_tuple(make_multi_index(MWaves * MPerXdl, 0), + make_multi_index(-MPerBlock, 0), + make_multi_index(-MPerBlock, ScaleSliceSizeK)); + constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK); - const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock; + constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock); - blockwise_gemm_pipeline.template Run( + blockwise_gemm_pipeline.template Run( a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_blockwise_copy, @@ -1411,6 +1392,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 b_grid_buf, b_block_buf, b_block_slice_copy_step, + + c_scale_thread_desc, c_thread_buf, a_scale_grid_desc_am_ak, @@ -1425,8 +1408,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 b_scale_grid_buf, b_scale_thread_slice_copy_step, - num_k_block_main_loop, - num_k_block_per_scale); + num_k_block_main_loop); // shuffle C and write out { @@ -1437,23 +1419,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + // transposed XDL + // // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + // // TODO: hacky, fix it! + // only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); @@ -1462,24 +1445,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 static_cast(p_shared), c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, make_tuple( make_freeze_transform(I0), make_unmerge_transform(make_tuple( Number{}, // M0 (MXdlPerWave) per shuffle M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), + M2)), // M2 = MPerXdl make_freeze_transform(I0), make_unmerge_transform(make_tuple( Number{}, // N0 (NXdlPerWave) per shuffle N1, // N1 = NWave - N2))), // N2 = NPerXdl + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index @@ -1489,57 +1472,57 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), make_tuple(Sequence<0, 1, 2>{}), make_tuple(Sequence<0>{})); + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_block)); // shuffle: threadwise copy C from VGPR to LDS auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3, + N2, + I1, + N4>, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, 7, 1, InMemoryDataOperationEnum::Set, 1, true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, make_multi_index(0, 0, m_thread_data_on_block_idx[I1], n_thread_data_on_block_idx[I1], m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + tensor_operation::element_wise::PassThrough{}}; using EDataType = CDataType; @@ -1621,18 +1604,17 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), c_element_op}; - // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = - SpaceFillingCurve, + SpaceFillingCurve, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, Sequence>{}; + N2, + 1, + N4>>{}; constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); @@ -1652,10 +1634,10 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 block_sync_lds(); // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, c_shuffle_block_buf); // make sure it's safe to read from LDS diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index a8c95b9c38..25f600d68d 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -58,6 +58,7 @@ #include "ck_tile/core/tensor/transpose_tile.hpp" #include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/env.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional_with_tuple.hpp" #include "ck_tile/core/utility/ignore.hpp" diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index c761fcb8c3..090b2bf797 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -29,6 +29,12 @@ #include "hip/hip_fp16.h" #endif +#include "ck_tile/core/utility/env.hpp" + +// environment variable to enable logging: +// export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED +CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING) + #ifdef __HIPCC__ #define CK_TILE_HOST inline __host__ #define CK_TILE_DEVICE inline __device__ diff --git a/include/ck_tile/core/utility/env.hpp b/include/ck_tile/core/utility/env.hpp new file mode 100644 index 0000000000..5b0b7a9071 --- /dev/null +++ b/include/ck_tile/core/utility/env.hpp @@ -0,0 +1,204 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +namespace ck_tile { + +template +void CK_TILE_ERROR(Args&&... args) noexcept +{ + std::ostringstream oss; + (oss << ... << args); + std::cerr << "[ERROR] " << oss.str() << std::endl; +} + +namespace internal { + +template +bool is_any_of(const char* const (&names)[N], const std::string& str) +{ + return std::any_of(std::begin(names), std::end(names), [&](const char* inner_str) { + return str == inner_str; + }); +}; + +template +struct ParseEnvVal +{ +}; +template <> +struct ParseEnvVal +{ + static bool parse_env_var_value(const char* vp) + { + std::string value_env_str{vp}; + + for(auto& c : value_env_str) + { + if(std::isalpha(c) != 0) + { + c = std::tolower(static_cast(c)); + } + } + + if(is_any_of(enabled_names, value_env_str)) + { + return true; + } + else if(is_any_of(disabled_names, value_env_str)) + { + return false; + } + else + { + throw std::runtime_error("Invalid value for env variable"); + } + + return false; + } + + private: + static constexpr const char* enabled_names[] = {"enable", "enabled", "1", "yes", "on", "true"}; + static constexpr const char* disabled_names[] = { + "disable", "disabled", "0", "no", "off", "false"}; +}; + +// Supports hexadecimals (with leading "0x"), octals (if prefix is "0") and decimals (default). +// Returns 0 if environment variable is in wrong format (strtoull fails to parse the string). +template <> +struct ParseEnvVal +{ + static uint64_t parse_env_var_value(const char* vp) { return std::strtoull(vp, nullptr, 0); } +}; + +template <> +struct ParseEnvVal +{ + static std::string parse_env_var_value(const char* vp) { return std::string{vp}; } +}; + +template +struct EnvVar +{ + private: + T value{}; + bool is_unset = true; + + public: + const T& GetValue() const { return value; } + + bool IsUnset() const { return is_unset; } + + void Unset() { is_unset = true; } + + void UpdateValue(const T& val) + { + is_unset = false; + value = val; + } + + explicit EnvVar(const char* const name, const T& def_val) + { + // NOLINTNEXTLINE (concurrency-mt-unsafe) + const char* vp = std::getenv(name); + if(vp != nullptr) // a value was provided + { + is_unset = false; + value = ParseEnvVal::parse_env_var_value(vp); + } + else // no value provided, use default value + { + value = def_val; + } + } +}; +} // end namespace internal + +// Static inside function hides the variable and provides +// thread-safety/locking +// Used in global namespace +#define CK_TILE_DECLARE_ENV_VAR(name, type, default_val) \ + namespace ck_tile::env { \ + struct name \ + { \ + static_assert(std::is_same_v, \ + "CK_TILE_DECLARE_ENV* must be used in the global namespace"); \ + using value_type = type; \ + static ck_tile::internal::EnvVar& Ref() \ + { \ + static ck_tile::internal::EnvVar var{#name, default_val}; \ + return var; \ + } \ + }; \ + } + +#define CK_TILE_DECLARE_ENV_VAR_BOOL(name) CK_TILE_DECLARE_ENV_VAR(name, bool, false) + +#define CK_TILE_DECLARE_ENV_VAR_UINT64(name) CK_TILE_DECLARE_ENV_VAR(name, uint64_t, 0) + +#define CK_TILE_DECLARE_ENV_VAR_STR(name) CK_TILE_DECLARE_ENV_VAR(name, std::string, "") + +#define CK_TILE_ENV(name) \ + ck_tile::env::name {} + +template +inline const std::string& EnvGetString(EnvVar) +{ + static_assert(std::is_same_v); + return EnvVar::Ref().GetValue(); +} + +template +inline bool EnvIsEnabled(EnvVar) +{ + static_assert(std::is_same_v); + return !EnvVar::Ref().IsUnset() && EnvVar::Ref().GetValue(); +} + +template +inline bool EnvIsDisabled(EnvVar) +{ + static_assert(std::is_same_v); + return !EnvVar::Ref().IsUnset() && !EnvVar::Ref().GetValue(); +} + +template +inline uint64_t EnvValue(EnvVar) +{ + static_assert(std::is_same_v); + return EnvVar::Ref().GetValue(); +} + +template +inline bool EnvIsUnset(EnvVar) +{ + return EnvVar::Ref().IsUnset(); +} + +template +void EnvUnset(EnvVar) +{ + EnvVar::Ref().Unset(); +} + +/// Updates the cached value of an environment variable +template +void UpdateEnvVar(EnvVar, const ValueType& val) +{ + static_assert(std::is_same_v); + EnvVar::Ref().UpdateValue(val); +} + +template +void UpdateEnvVar(EnvVar, const std::string_view& val) +{ + EnvVar::Ref().UpdateValue( + ck_tile::internal::ParseEnvVal::parse_env_var_value( + val.data())); +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index d9d6739fb5..6024e00419 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -68,16 +68,6 @@ struct BlockUniversalGemmAsBsCr static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN; static constexpr index_t KPerBlockPerIter = WarpGemm::kK; - using AWarpTileDistr = remove_cvref_t; - using BWarpTileDistr = remove_cvref_t; - - using AWarpTile = remove_cvref_t( - AWarpTileDistr{}))>; - using BWarpTile = remove_cvref_t( - BWarpTileDistr{}))>; - // TODO: Should we have two policies? Interwave & Intrawave ?? static constexpr index_t InterWaveSchedulingMacClusters = 1; @@ -108,6 +98,25 @@ struct BlockUniversalGemmAsBsCr static constexpr auto Scheduler = Traits::Scheduler; + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + static constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr index_t APackedSize = ck_tile::numeric_traits>::PackedSize; static constexpr index_t BPackedSize = @@ -116,18 +125,65 @@ struct BlockUniversalGemmAsBsCr using I0 = number<0>; using I1 = number<1>; + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() + { + constexpr index_t KPerThread = Traits::KPerThread; + constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; + constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack); + constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK; + + using KIterSeq = std::conditional_t, + sequence>; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + + CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() + { + constexpr index_t KPerThread = Traits::KPerThread; + constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; + constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack); + constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK; + + using KIterSeq = std::conditional_t, + sequence>; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } + private: template - CK_TILE_DEVICE static void load_interleaved_pk_type(const WarpWindow& warp_window, - WarpTile& warp_tile) + CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, + const WarpWindow& warp_window) { constexpr index_t UnaryOpSize = 8; const element_wise::PassThroughPack8 elementwise_op{}; - constexpr index_t thread_buffer_size = - Traits::AWarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto in_dstr_tensors = load_tile(warp_window); + constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; + const auto in_dstr_tensors = load_tile(warp_window); - static_assert(Traits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0); + static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); static_for<0, thread_buffer_size, 1>{}([&](auto i) { @@ -144,6 +200,17 @@ struct BlockUniversalGemmAsBsCr template struct BlockGemmImpl { + static constexpr auto ALdsTileDistr = + decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile a_warp_tile_; + ALdsTile b_warp_tile_; + // C += A * B template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, @@ -158,114 +225,39 @@ struct BlockUniversalGemmAsBsCr "The ADataType and BDataType as defined in " "traits should be the same as correspoinding block window data type!"); - static_assert( - GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] && - GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] && - GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}], - "MPerBlock, NPerBlock, KPerBlock defined in " - " BlockGemmShape are different from A/B block smem windows apropriate dims!"); - - const index_t iMWarp = get_warp_id() / NWarp; - const index_t iNWarp = get_warp_id() - (iMWarp * NWarp); - - // TODO: refactor warp_window tile type to class member as it should be - // compile-time known information. - auto a_warp_window_tmp = make_tile_window( - a_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0}, - make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); - - using AWarpWindow = remove_cvref_t; - - static_assert(GemmTraits::AWarpTile::get_num_of_dimension() == - AWarpWindow::get_num_of_dimension(), - "AWarpWindow number of dimensions must be equal to " - "AWarpTile number of dimensions!"); - static_assert(GemmTraits::AWarpTile::get_lengths() == - AWarpWindow{}.get_window_lengths(), - "AWarpWindow lengths must be equal to AWarpTile lengths!"); - - statically_indexed_array< - statically_indexed_array, - MIterPerWarp> - a_warp_windows; - - // construct B-warp-window - auto b_warp_window_tmp = make_tile_window( - b_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0}, - make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); - - using BWarpWindow = remove_cvref_t; - - static_assert(GemmTraits::BWarpTile::get_num_of_dimension() == - BWarpWindow::get_num_of_dimension(), - "BWarpWindow number of dimensions must be equal to " - "BWarpTile number of dimensions!"); - static_assert(GemmTraits::BWarpTile::get_lengths() == - BWarpWindow{}.get_window_lengths(), - "BWarpWindow lengths must be equal to BWarpTile lengths!"); - - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> - b_warp_windows; - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - - // TODO: I don't have to move 0,0 window! - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * GemmTraits::MPerBlockPerIter, - kIter * GemmTraits::KPerBlockPerIter}); - }); - }); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * GemmTraits::NPerBlockPerIter, - kIter * GemmTraits::KPerBlockPerIter}); - }); - }); - - using CWarpDstr = typename WarpGemm::CWarpDstr; - using AWarpTensor = typename WarpGemm::AWarpTensor; - using BWarpTensor = typename WarpGemm::BWarpTensor; - using CWarpTensor = typename WarpGemm::CWarpTensor; - - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(a_warp_tile_, a_block_window); + } + else + { + load_tile(a_warp_tile_, a_block_window); + } + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(b_warp_tile_, b_block_window); + } + else + { + load_tile(b_warp_tile_, b_block_window); + } // hot loop: static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - AWarpTensor a_warp_tile; - if constexpr(std::is_same_v) - { - load_interleaved_pk_type(a_warp_windows(mIter)(kIter), a_warp_tile); - } - else - { - a_warp_tile = load_tile(a_warp_windows(mIter)(kIter)); - } + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - BWarpTensor b_warp_tile; - if constexpr(std::is_same_v) - { - load_interleaved_pk_type(b_warp_windows(nIter)(kIter), b_warp_tile); - } - else - { - b_warp_tile = load_tile(b_warp_windows(nIter)(kIter)); - } + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); // read C warp tensor from C block tensor- CWarpTensor c_warp_tensor; @@ -275,7 +267,7 @@ struct BlockUniversalGemmAsBsCr merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tile, b_warp_tile); + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( @@ -291,149 +283,68 @@ struct BlockUniversalGemmAsBsCr template struct BlockGemmImpl { - statically_indexed_array< - statically_indexed_array, - MIterPerWarp> - a_warp_tiles_; + static constexpr auto ALdsTileDistr = + decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> - b_warp_tiles_; + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile a_warp_tile_; + ALdsTile b_warp_tile_; template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, const BSmemBlockWindow& b_block_window) { - static_assert( - GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] && - GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] && - GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}], - "MPerBlock, NPerBlock, KPerBlock defined in " - " BlockGemmShape are different from A/B block smem windows apropriate dims!"); - - static_assert(std::is_same_v && - std::is_same_v, - "The ADataType and BDataType as defined in " - "traits should be the same as correspoinding block window data type!"); - - const index_t iMWarp = get_warp_id() / NWarp; - const index_t iNWarp = get_warp_id() - (iMWarp * NWarp); - - // TODO: refactor warp_window tile type to class member as it should be - // compile-time known information. - auto a_warp_window_tmp = make_tile_window( - a_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0}, - make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); - - using AWarpWindow = remove_cvref_t; - - static_assert(GemmTraits::AWarpTile::get_num_of_dimension() == - AWarpWindow::get_num_of_dimension(), - "AWarpWindow number of dimensions must be equal to " - "AWarpTile number of dimensions!"); - static_assert(GemmTraits::AWarpTile::get_lengths() == - AWarpWindow{}.get_window_lengths(), - "AWarpWindow lengths must be equal to AWarpTile lengths!"); - - statically_indexed_array, - MIterPerWarp> - a_warp_windows; - - // construct B-warp-window - auto b_warp_window_tmp = make_tile_window( - b_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0}, - make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); - - using BWarpWindow = remove_cvref_t; - - static_assert(GemmTraits::BWarpTile::get_num_of_dimension() == - BWarpWindow::get_num_of_dimension(), - "BWarpWindow number of dimensions must be equal to " - "BWarpTile number of dimensions!"); - static_assert(GemmTraits::BWarpTile::get_lengths() == - BWarpWindow{}.get_window_lengths(), - "BWarpWindow lengths must be equal to BWarpTile lengths!"); - - statically_indexed_array, - NIterPerWarp> - b_warp_windows; - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - - // TODO: I don't have to move 0,0 window! - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * GemmTraits::MPerBlockPerIter, - kIter * GemmTraits::KPerBlockPerIter}); - }); - }); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * GemmTraits::NPerBlockPerIter, - kIter * GemmTraits::KPerBlockPerIter}); - }); - }); - - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block window - if constexpr(std::is_same_v) - { - load_interleaved_pk_type(a_warp_windows(mIter)(kIter), - a_warp_tiles_(mIter)(kIter)); - } - else - { - a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter)); - } - }); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - if constexpr(std::is_same_v) - { - load_interleaved_pk_type(b_warp_windows(nIter)(kIter), - b_warp_tiles_(nIter)(kIter)); - } - else - { - b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter)); - } - }); - }); + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(a_warp_tile_, a_block_window); + } + else + { + load_tile(a_warp_tile_, a_block_window); + } + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(b_warp_tile_, b_block_window); + } + else + { + load_tile(b_warp_tile_, b_block_window); + } } // C += A * B template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - [[maybe_unused]] const ASmemBlockWindow& a_block_window, - [[maybe_unused]] const BSmemBlockWindow& b_block_window) + [[maybe_unused]] ASmemBlockWindow& a_block_window, + [[maybe_unused]] BSmemBlockWindow& b_block_window) { static_assert(std::is_same_v, "The CDataType as defined in traits should be the same as correspoinding " "C block tensor data type!"); - using CWarpDstr = typename WarpGemm::CWarpDstr; - using CWarpTensor = typename WarpGemm::CWarpTensor; - - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - // hot loop: static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor- + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( @@ -441,9 +352,7 @@ struct BlockUniversalGemmAsBsCr merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); // warp GEMM - WarpGemm{}(c_warp_tensor, - a_warp_tiles_[mIter][kIter], - b_warp_tiles_[nIter][kIter]); + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( @@ -468,126 +377,53 @@ struct BlockUniversalGemmAsBsCr static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; static constexpr index_t KInnerLoopIter = KPerInnerLoop / GemmTraits::KPack; - statically_indexed_array< - statically_indexed_array, - MIterPerWarp> - a_warp_tiles_; + static constexpr auto ALdsTileDistr = + decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> - b_warp_tiles_; + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile a_warp_tile_; + ALdsTile b_warp_tile_; template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, const BSmemBlockWindow& b_block_window) { - static_assert( - GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] && - GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] && - GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}], - "MPerBlock, NPerBlock, KPerBlock defined in " - " BlockGemmShape are different from A/B block smem windows apropriate dims!"); + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(MakeBBlockDistributionEncode()); - static_assert(std::is_same_v && - std::is_same_v, - "The ADataType and BDataType as defined in " - "traits should be the same as correspoinding block window data type!"); - - const index_t iMWarp = get_warp_id() / NWarp; - const index_t iNWarp = get_warp_id() - (iMWarp * NWarp); - - // TODO: refactor warp_window tile type to class member as it should be - // compile-time known information. - auto a_warp_window_tmp = make_tile_window( + auto a_lds_gemm_window = make_tile_window( a_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_block_window.get_window_origin() + - multi_index<2>{iMWarp * WarpGemm::kM, KIdx * KPerInnerLoop}, - make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); - - using AWarpWindow = remove_cvref_t; - - static_assert(GemmTraits::AWarpTile::get_num_of_dimension() == - AWarpWindow::get_num_of_dimension(), - "AWarpWindow number of dimensions must be equal to " - "AWarpTile number of dimensions!"); - static_assert(GemmTraits::AWarpTile::get_lengths() == - AWarpWindow{}.get_window_lengths(), - "AWarpWindow lengths must be equal to AWarpTile lengths!"); - - statically_indexed_array, - MIterPerWarp> - a_warp_windows; - - // construct B-warp-window - auto b_warp_window_tmp = make_tile_window( + make_tuple(number{}, number{}), + {0, KIdx * KPerInnerLoop}, + a_lds_load_tile_distr); + auto b_lds_gemm_window = make_tile_window( b_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_block_window.get_window_origin() + - multi_index<2>{iNWarp * WarpGemm::kN, KIdx * KPerInnerLoop}, - make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); + make_tuple(number{}, number{}), + {0, KIdx * KPerInnerLoop}, + b_lds_load_tile_distr); - using BWarpWindow = remove_cvref_t; - - static_assert(GemmTraits::BWarpTile::get_num_of_dimension() == - BWarpWindow::get_num_of_dimension(), - "BWarpWindow number of dimensions must be equal to " - "BWarpTile number of dimensions!"); - static_assert(GemmTraits::BWarpTile::get_lengths() == - BWarpWindow{}.get_window_lengths(), - "BWarpWindow lengths must be equal to BWarpTile lengths!"); - - statically_indexed_array, - NIterPerWarp> - b_warp_windows; - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * GemmTraits::MPerBlockPerIter, - kIter * GemmTraits::KPerBlockPerIter}); - }); - }); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * GemmTraits::NPerBlockPerIter, - kIter * GemmTraits::KPerBlockPerIter}); - }); - }); - - // TODO check if a_warp_tiles has same desc as a_warp_window - static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - if constexpr(std::is_same_v) - { - load_interleaved_pk_type(a_warp_windows(mIter)(kIter), - a_warp_tiles_(mIter)(kIter)); - } - else - { - a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter)); - } - }); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - if constexpr(std::is_same_v) - { - load_interleaved_pk_type(b_warp_windows(nIter)(kIter), - b_warp_tiles_(nIter)(kIter)); - } - else - { - b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter)); - } - }); - }); + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(a_warp_tile_, a_block_window); + } + else + { + load_tile(a_warp_tile_, a_lds_gemm_window); + } + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(b_warp_tile_, b_block_window); + } + else + { + load_tile(b_warp_tile_, b_lds_gemm_window); + } } // C += A * B @@ -600,13 +436,6 @@ struct BlockUniversalGemmAsBsCr "The CDataType as defined in traits should be the same as correspoinding " "C block tensor data type!"); - using CWarpDstr = typename WarpGemm::CWarpDstr; - using CWarpTensor = typename WarpGemm::CWarpTensor; - - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - // hot loop: static_for<0, KRepeat, 1>{}([&](auto kIter) { LocalPrefetch(a_block_window, b_block_window); @@ -626,7 +455,21 @@ struct BlockUniversalGemmAsBsCr static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor.get_thread_buffer() = + b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, + b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); // read C warp tensor from C block tensor- CWarpTensor c_warp_tensor; @@ -651,9 +494,7 @@ struct BlockUniversalGemmAsBsCr __builtin_amdgcn_sched_barrier(0); } // warp GEMM - WarpGemm{}(c_warp_tensor, - a_warp_tiles_[mIter][kInnerIter], - b_warp_tiles_[nIter][kInnerIter]); + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 3107d07bc9..972c71e93b 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -129,34 +129,34 @@ struct GemmKernel const std::size_t k_id = blockIdx.z) { constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); - const index_t K_t = kargs.k_batch * K1; - const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; + const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); + const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1); if constexpr(std::is_same_v) { - a_k_split_offset = k_id * KRead; + a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); } else if constexpr(std::is_same_v) { - a_k_split_offset = k_id * KRead * kargs.stride_A; + a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A); } if constexpr(std::is_same_v) { - b_k_split_offset = k_id * KRead * kargs.stride_B; + b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B); } else if constexpr(std::is_same_v) { - b_k_split_offset = k_id * KRead; + b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); } if(k_id < static_cast(kargs.k_batch - 1)) { - splitted_k = KRead; + splitted_k = __builtin_amdgcn_readfirstlane(KRead); } else { - splitted_k = kargs.K - KRead * (kargs.k_batch - 1); + splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1)); } } @@ -172,23 +172,32 @@ struct GemmKernel { if(kargs.k_batch != 1) { - std::cerr << "Conditions not met for Kbatch >1 !" << std::endl; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + } return false; } } if constexpr(std::is_same_v) { - if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false) + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) { - std::cerr << "Can't support K that is not a multiple of KPerBlock" - " without padding!" - << std::endl; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } return false; } if(kargs.K % GemmPipeline::GetVectorSizeA() != 0) { - std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + } return false; } } @@ -196,14 +205,19 @@ struct GemmKernel { if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) { - std::cerr << "Can't support M that is not a multiple of MPerBlock" - " without padding!" - << std::endl; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } return false; } if(kargs.M % GemmPipeline::GetVectorSizeA() != 0) { - std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!"); + } return false; } } @@ -212,29 +226,40 @@ struct GemmKernel { if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) { - std::cerr << "Can't support N that is not a multiple of NPerBlock" - " without padding!" - << std::endl; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support N that is not a multiple of NPerBlock without padding!"); + } return false; } if(kargs.N % GemmPipeline::GetVectorSizeB() != 0) { - std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); + } return false; } } else { - if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false) + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) { - std::cerr << "Can't support K that is not a multiple of KPerBlock" - " without padding!" - << std::endl; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } return false; } if(kargs.K % GemmPipeline::GetVectorSizeB() != 0) { - std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!"); + } return false; } } @@ -243,14 +268,19 @@ struct GemmKernel { if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) { - std::cerr << "Can't support N that is not a multiple of NPerBlock" - " without padding!" - << std::endl; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support N that is not a multiple of NPerBlock without padding!"); + } return false; } if(kargs.N % EpiloguePipeline::template GetVectorSizeC() != 0) { - std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!"); + } return false; } } @@ -258,14 +288,19 @@ struct GemmKernel { if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) { - std::cerr << "Can't support M that is not a multiple of MPerBlock" - " without padding!" - << std::endl; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } return false; } if(kargs.M % EpiloguePipeline::template GetVectorSizeC() != 0) { - std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl; + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!"); + } return false; } } @@ -279,6 +314,7 @@ struct GemmKernel const GemmKernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset) { + static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { @@ -303,21 +339,63 @@ struct GemmKernel const auto& b_tensor_view = [&]() { if constexpr(std::is_same_v) { - return make_naive_tensor_view( - b_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.N), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + return make_naive_tensor_view( + b_ptr, + make_tuple(splitk_batch_offset.splitted_k, kargs.N), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } } else { - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } } }(); @@ -488,7 +566,8 @@ struct GemmKernel const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); @@ -539,7 +618,8 @@ struct GemmKernel const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); @@ -558,7 +638,8 @@ struct GemmKernel CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const { - const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x); + const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); @@ -572,11 +653,11 @@ struct GemmKernel // allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; - __shared__ char smem_ptr_1[GetSmemSize()]; - if(kargs.k_batch == 1) + if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - if constexpr(GemmPipeline::DoubleSmemBuffer == true) + __shared__ char smem_ptr_1[GetSmemSize()]; + if(kargs.k_batch == 1) { RunGemm2LDS(a_ptr, b_ptr, @@ -590,17 +671,8 @@ struct GemmKernel } else { - RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); - } - } - else - { - // Do not compile in case where we have unsupported - // VectorSizeC & data type configuration. - if constexpr(!(EpiloguePipeline::template GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - if constexpr(GemmPipeline::DoubleSmemBuffer == true) + if constexpr(!(EpiloguePipeline::template GetVectorSizeC() % 2 != 0 && + is_any_of::value)) { RunGemm2LDS(a_ptr, b_ptr, @@ -612,7 +684,18 @@ struct GemmKernel i_m, i_n); } - else + } + } + else + { + if(kargs.k_batch == 1) + { + RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + } + else + { + if constexpr(!(EpiloguePipeline::template GetVectorSizeC() % 2 != 0 && + is_any_of::value)) { RunGemm( a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 4855df0e0e..24bd66a59e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -68,9 +68,10 @@ struct GemmPipelineAgBgCrImplBase return make_tuple(std::move(a_lds_block), std::move(b_lds_block)); } - template - CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const ALdsTensorView& a_lds_block_view) const + template + CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const ALdsTensorView& a_lds_block_view, + const ALdsLoadTileDistr&) const { constexpr bool is_col_major = std::is_same_v; @@ -88,17 +89,21 @@ struct GemmPipelineAgBgCrImplBase auto a_copy_lds_window = make_tile_window( a_lds_block_view, make_tuple(number{}, number{}), {0, 0}); - auto a_lds_gemm_window = make_tile_window( - a_lds_block_view, make_tuple(number{}, number{}), {0, 0}); + auto a_lds_gemm_window = + make_tile_window(a_lds_block_view, + make_tuple(number{}, number{}), + {0, 0}, + ALdsLoadTileDistr{}); return make_tuple(std::move(a_copy_dram_window), std::move(a_copy_lds_window), std::move(a_lds_gemm_window)); } - template - CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, - const BLdsTensorView& b_lds_block_view) const + template + CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BLdsTensorView& b_lds_block_view, + const BLdsLoadTileDistr&) const { constexpr bool is_row_major = std::is_same_v; @@ -117,8 +122,11 @@ struct GemmPipelineAgBgCrImplBase auto b_copy_lds_window = make_tile_window( b_lds_block_view, make_tuple(number{}, number{}), {0, 0}); - auto b_lds_gemm_window = make_tile_window( - b_lds_block_view, make_tuple(number{}, number{}), {0, 0}); + auto b_lds_gemm_window = + make_tile_window(b_lds_block_view, + make_tuple(number{}, number{}), + {0, 0}, + BLdsLoadTileDistr{}); return make_tuple(std::move(b_copy_dram_window), std::move(b_copy_lds_window), diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 73d5ce8f81..1e3694d24c 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -77,6 +77,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; @@ -114,11 +117,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); // Below should be equal to AK1|BK1 - constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA(); - constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB(); + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); - constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA(); - constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB(); + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); constexpr index_t A_Buffer_Load_Inst_Num = MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); @@ -174,11 +177,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); // Below should be equal to AK1|BK1 - constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA(); - constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB(); + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); - constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA(); - constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB(); + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); constexpr index_t A_Buffer_Load_Inst_Num = MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); @@ -346,17 +349,23 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 // A/B tiles in LDS auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + // A DRAM tile window for load // A LDS tile window for store // A LDS tile for block GEMM auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = - Base::GetAWindows(a_dram_block_window_tmp, a_lds_block); + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); // B DRAM tile window for load // B LDS tile window for store // B LDS tile for block GEMM auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = - Base::GetBWindows(b_dram_block_window_tmp, b_lds_block); + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); // Block GEMM auto block_gemm = BlockGemm(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index b679f8c8aa..f95d80a6f5 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -86,6 +86,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index b8b2d5b1c9..abf5b617ee 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -129,6 +129,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; @@ -215,10 +218,17 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto& a_lds_block = ab_lds_blocks.at(I0{}); auto& b_lds_block = ab_lds_blocks.at(I1{}); + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution( + BlockGemm::MakeABlockDistributionEncode())){}; + constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution( + BlockGemm::MakeBBlockDistributionEncode())){}; + // A DRAM tile window for load // A LDS tile window for store // A LDS tile for block GEMM - auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block); + auto a_windows = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); auto& a_copy_dram_window = a_windows.at(I0{}); auto& a_copy_lds_window = a_windows.at(I1{}); auto& a_lds_gemm_window = a_windows.at(I2{}); @@ -226,7 +236,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // B DRAM tile window for load // B LDS tile window for store // B LDS tile for block GEMM - auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block); + auto b_windows = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); auto& b_copy_dram_window = b_windows.at(I0{}); auto& b_copy_lds_window = b_windows.at(I1{}); auto& b_lds_gemm_window = b_windows.at(I2{}); @@ -493,10 +504,17 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto& a_lds_block = ab_lds_blocks.at(I0{}); auto& b_lds_block = ab_lds_blocks.at(I1{}); + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution( + BlockGemm::MakeABlockDistributionEncode())){}; + constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution( + BlockGemm::MakeBBlockDistributionEncode())){}; + // A DRAM tile window for load // A LDS tile window for store // A LDS tile for block GEMM - auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block); + auto a_windows = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); auto& a_copy_dram_window = a_windows.at(I0{}); auto& a_copy_lds_window = a_windows.at(I1{}); auto& a_lds_gemm_window = a_windows.at(I2{}); @@ -504,7 +522,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // B DRAM tile window for load // B LDS tile window for store // B LDS tile for block GEMM - auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block); + auto b_windows = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); auto& b_copy_dram_window = b_windows.at(I0{}); auto& b_copy_lds_window = b_windows.at(I1{}); auto& b_lds_gemm_window = b_windows.at(I2{}); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 33945651ae..41ea89b2bd 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -36,6 +36,9 @@ struct GemmPipelineAGmemBGmemCRegV1 static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; } static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; @@ -125,13 +128,25 @@ struct GemmPipelineAGmemBGmemCRegV1 auto b_copy_lds_window = make_tile_window( b_lds_block, make_tuple(number{}, number{}), {0, 0}); + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + // A LDS tile for block GEMM - auto a_lds_gemm_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); + auto a_lds_gemm_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + a_lds_load_tile_distr); // B LDS tile for block GEMM - auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); + auto b_lds_gemm_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_lds_load_tile_distr); // Block GEMM auto block_gemm = BlockGemm(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index fe706113ae..95b7618b11 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -31,6 +31,9 @@ struct GemmPipelineAGmemBGmemCRegV2 static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kKPerBlock = BlockGemmShape::kK; + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off @@ -122,17 +125,29 @@ struct GemmPipelineAGmemBGmemCRegV2 {0, 0}, b_copy_dram_window.get_tile_distribution()); - // A LDS tile for block GEMM - auto a_lds_gemm_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // B LDS tile for block GEMM - auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); - // Block GEMM constexpr auto block_gemm = Policy::template GetBlockGemm(); + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(decltype(block_gemm)::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(decltype(block_gemm)::MakeBBlockDistributionEncode()); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + a_lds_load_tile_distr); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_lds_load_tile_distr); + // Acc register tile auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp index 24a399f18d..f0aa4472e1 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -8,7 +8,11 @@ namespace ck_tile { -template +template struct TileGemmShape { using BlockTile = remove_cvref_t; @@ -21,6 +25,9 @@ struct TileGemmShape static constexpr index_t kN = BlockTile::at(number<1>{}); static constexpr index_t kK = BlockTile::at(number<2>{}); + static constexpr bool PermuteA = PermuteA_; + static constexpr bool PermuteB = PermuteB_; + CK_TILE_HOST static std::string GetName() { // clang-format off diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp index 7553d5e76e..3fa82ae53a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp @@ -17,7 +17,7 @@ namespace tensor_operation { namespace device { namespace instance { #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances( +void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( std::vector, @@ -28,14 +28,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i F32, Tuple<>, BF16, - 128, + 1, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances( +void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( std::vector, @@ -46,14 +46,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_ F32, Tuple<>, BF16, - 128, + 1, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances( +void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( std::vector, @@ -64,14 +64,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding F32, Tuple<>, BF16, - 128, + 1, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances( +void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( std::vector, @@ -82,61 +82,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpaddin F32, Tuple<>, BF16, - 128, - 128, - 128, - PassThrough, - PassThrough, - PassThrough>>>& instances); - -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances( - std::vector, - Row, - F8, - F32, - F8, - F32, - Tuple<>, - BF16, - 128, - 128, - 128, - PassThrough, - PassThrough, - PassThrough>>>& instances); - -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances( - std::vector, - Row, - F8, - F32, - F8, - F32, - Tuple<>, - BF16, - 128, - 128, - 128, - PassThrough, - PassThrough, - PassThrough>>>& instances); - -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances( - std::vector, - Row, - F8, - F32, - F8, - F32, - Tuple<>, - BF16, - 128, + 1, 128, 128, PassThrough, @@ -163,7 +109,7 @@ struct DeviceOperationInstanceFactory, CDataType, - 128, + 1, 128, 128, ck::tensor_operation::element_wise::PassThrough, @@ -180,7 +126,7 @@ struct DeviceOperationInstanceFactory, CDataType, - 128, + 1, 128, 128, ck::tensor_operation::element_wise::PassThrough, @@ -198,20 +144,14 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances( + add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( op_ptrs); - add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances( - op_ptrs); - add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances( - op_ptrs); - add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances( + add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( op_ptrs); - add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances( + add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( op_ptrs); - add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances( - op_ptrs); - add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances( + add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( op_ptrs); } } diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt index aab1c4e86e..d572862884 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt @@ -4,16 +4,13 @@ set(GEMM_AB_SCALE_INSTANCES) list(APPEND GEMM_AB_SCALE_INSTANCES device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp - device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp - device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp - device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp ) set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") add_instance_library(device_gemm_ab_scale_instance ${GEMM_AB_SCALE_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp index 3a7df8d974..eba9cfcb7c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp @@ -34,49 +34,50 @@ static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; template -using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances = std::tuple< +using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = std::tuple< // clang-format off - //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| - //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Compute friendly - // Spill in current compiler - // DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - // DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; template -using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances = std::tuple< +using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances = std::tuple< // clang-format off - //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| - //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //################################| ALayout| BLayout| DsLayout| ELayout|AData | BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################| | | | | Type | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // Latency friendly - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - // Memory friendly - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 128, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 64, 32, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 64, 16, 128, 16, 16, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 64, 128, 16, 16, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 128, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, - DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8> + // Memory friendly + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp index ab83c7eb3e..aebffc01f2 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp @@ -8,7 +8,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances( +void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( std::vector, @@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i F32, Tuple<>, BF16, - 128, + 1, 128, 128, PassThrough, @@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i { add_device_operation_instances( instances, - device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances{}); + device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp index dfb1bb6e2d..31fffae080 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp @@ -8,7 +8,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances( +void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( std::vector, @@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_ F32, Tuple<>, BF16, - 128, + 1, 128, 128, PassThrough, @@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_ { add_device_operation_instances( instances, - device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances{}); + device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp deleted file mode 100644 index d2d3ebe81e..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp +++ /dev/null @@ -1,37 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances( - std::vector, - Row, - F8, - F32, - F8, - F32, - Tuple<>, - BF16, - 128, - 128, - 128, - PassThrough, - PassThrough, - PassThrough>>>& instances) -{ - add_device_operation_instances( - instances, - device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp deleted file mode 100644 index f6ce77a751..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp +++ /dev/null @@ -1,37 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances( - std::vector, - Row, - F8, - F32, - F8, - F32, - Tuple<>, - BF16, - 128, - 128, - 128, - PassThrough, - PassThrough, - PassThrough>>>& instances) -{ - add_device_operation_instances( - instances, - device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp index e2205ad728..569911e3de 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp @@ -8,7 +8,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances( +void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( std::vector, @@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default F32, Tuple<>, BF16, - 128, + 1, 128, 128, PassThrough, @@ -28,8 +28,8 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default { add_device_operation_instances( instances, - device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances{}); + device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp index 5c0a6eb00d..d1e5b6b535 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp @@ -8,7 +8,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances( +void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( std::vector, @@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpaddin F32, Tuple<>, BF16, - 128, + 1, 128, 128, PassThrough, @@ -28,8 +28,8 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpaddin { add_device_operation_instances( instances, - device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances{}); + device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp deleted file mode 100644 index cc1a03b060..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp +++ /dev/null @@ -1,38 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances( - std::vector, - Row, - F8, - F32, - F8, - F32, - Tuple<>, - BF16, - 128, - 128, - 128, - PassThrough, - PassThrough, - PassThrough>>>& instances) -{ - add_device_operation_instances( - instances, - device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/profiler/src/profile_gemm_ab_scale.cpp b/profiler/src/profile_gemm_ab_scale.cpp index 56c8b5e7a1..3956038a30 100644 --- a/profiler/src/profile_gemm_ab_scale.cpp +++ b/profiler/src/profile_gemm_ab_scale.cpp @@ -32,6 +32,7 @@ enum struct GemmDataType enum struct ScaleBlockTile { Tile_128_128_128, // 0 + Tile_1_128_128, // 1 }; #define OP_NAME "gemm_ab_scale" @@ -49,7 +50,8 @@ int profile_gemm_ab_scale(int argc, char* argv[]) printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); - printf("arg4: scale block tile (0: ScaleBlockM/N/K = [128, 128, 128];\n"); + printf("arg4: scale block tile (0: ScaleBlockM/N/K = [128, 128, 128]; 1: ScaleBlockM/N/K = " + "[1, 128, 128];\n"); printf("arg5: verification (0: no; 1: yes)\n"); printf("arg6: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg7: print tensor value (0: no; 1: yes)\n"); @@ -155,7 +157,7 @@ int profile_gemm_ab_scale(int argc, char* argv[]) }; if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN && - scale_block_tile == ScaleBlockTile::Tile_128_128_128) + scale_block_tile == ScaleBlockTile::Tile_1_128_128) { return profile(F8{}, F32{}, @@ -164,7 +166,7 @@ int profile_gemm_ab_scale(int argc, char* argv[]) F8{}, F32{}, BF16{}, - ck::Number<128>{}, + ck::Number<1>{}, ck::Number<128>{}, ck::Number<128>{}, Row{}, diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 155234cddc..3a9203a5bf 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -71,7 +71,9 @@ class TestCkTileGemmPipeline : public ::testing::Test constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; + // TODO: Restore to 8. At now after changes in block_universal_gemm_as_bs_cr it return wrong + // values. + constexpr ck_tile::index_t K_Warp_Tile = 16; constexpr bool kPadM = PadM; constexpr bool kPadN = PadN;