diff --git a/CMakeLists.txt b/CMakeLists.txt index bb0c254e06..b7337a7f83 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -533,11 +533,6 @@ include_directories(BEFORE ${HIP_INCLUDE_DIRS} ) -SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") -if(BUILD_DEV) - add_compile_options(-Werror) - add_compile_options(-Weverything) -endif() message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index fb2b38d688..d5bcd6f978 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -66,7 +66,6 @@ else() -Wunreachable-code -Wunused -Wno-reserved-identifier - -Werror -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp index e4e6a4f1a7..7cb0b0c523 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -333,14 +333,8 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - size_t total_size = - (M * K * sizeof(A0DataType) + N * K * sizeof(B0DataType) + M * sizeof(D0DataType) + - N * sizeof(D1DataType) + M * N * sizeof(EDataType)); - int rotate_buf_num = - ck::math::min(size_t(Repeat), ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size)); - float ave_time = invoker.Run( - argument, StreamConfig{nullptr, time_kernel, 0, Warmup, Repeat, true, rotate_buf_num}); + argument, StreamConfig{nullptr, time_kernel, 0, Warmup, Repeat, false, 1}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 9ba3a453fc..16f8463b73 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -21,7 +21,7 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") # generate a list of kernels, but not actually emit files at config sta execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt + --api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt --receipt 200 RESULT_VARIABLE ret ) if(ret AND NOT ret EQUAL 0) @@ -45,7 +45,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) add_custom_command( OUTPUT ${FMHA_FWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} + --api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 200 ) add_custom_command( diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 332707eafd..fb9c9ab951 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -3,11 +3,11 @@ # generate kernel instances to speed up compilation FWD_DTYPE_MAP = { - "fp16" : "FmhaFwdFp16", + # "fp16" : "FmhaFwdFp16", "bf16" : "FmhaFwdBf16", - "fp8" : "FmhaFwdFp8", - "fp8fp16": "FmhaFwdFp8Fp16", - "fp8bf16": "FmhaFwdFp8Bf16" + # "fp8" : "FmhaFwdFp8", + # "fp8fp16": "FmhaFwdFp8Fp16", + # "fp8bf16": "FmhaFwdFp8Bf16" } BWD_DTYPE_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 4ff7ede765..aef12a6e1f 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -118,7 +118,7 @@ FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} """ -FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ {F_inner_dispatch} }} """ @@ -288,7 +288,7 @@ class FmhaFwdApiPool: F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) if not per_dtypes: @@ -413,18 +413,17 @@ class FmhaFwdKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + # '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + # '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + # '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), } else: return None @@ -441,7 +440,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm pipelines = [] if dtype in ['fp16', 'bf16']: for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): - if hdim == 256: + if False: # if True: pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) @@ -449,20 +448,24 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) else: - if bias == "bias": - # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + # if bias == "bias": + # # TODO: rocm 6.2 compiler problem if using qr_async for bias case + # pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + # else: + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + # if receipt == 1 and bias != "bias": + # pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + # pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): @@ -490,10 +493,6 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue - if hdim == 192 and tile.F_bn1 == 128: - # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't' or (pipeline.F_mask not in ['no', 's_no']): - continue k = FmhaFwdKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, @@ -534,6 +533,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm cond &= mode == 'group' cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_bias == 'no' + cond &= pipeline.F_lse == 'f' + cond &= pipeline.F_dropout == 'f' if not cond: continue api_pool.register_traits(k.api_trait()) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index b3855e59df..423bb5c782 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -480,6 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser) const auto seqstart_q_host = to_seqstarts(seqlen_qs); const auto seqstart_k_host = to_seqstarts(seqlen_ks); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); + // std::vector page_idx_host(seqstart_k_host.back(), 0); + ck_tile::HostTensor page_idx_host({seqstart_k_host.back()}); + // std::iota(page_idx_host.begin(), page_idx_host.end(), 0); + iota_shuffle(page_idx_host.mData.begin(), page_idx_host.mData.end(), 0); + // for (int i = 0; i < page_idx_host.get_element_space_size(); i++) { + // page_idx_host(i) = (i + 19) % page_idx_host.size(); + // } + page_idx_host.savetxt("page_idx_host.txt", "int"); using TypeConfig = FmhaFwdTypeConfig; @@ -585,7 +593,9 @@ bool run(const ck_tile::ArgParser& arg_parser) }; bool is_v_rowmajor = vlayout == std::string("r"); - + assert(is_v_rowmajor); + assert(!i_perm); + // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); const ck_tile::index_t shape_seqlen_q = @@ -601,6 +611,8 @@ bool run(const ck_tile::ArgParser& arg_parser) 0 < page_block_size ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + ck_tile::HostTensor k_host_sgl({seqstart_k_host.back(), nhead_k, hdim_q}); + /// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode ck_tile::HostTensor knew_host( 0 < seqlen_knew @@ -613,6 +625,7 @@ bool run(const ck_tile::ArgParser& arg_parser) : get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size)) : (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k))); + ck_tile::HostTensor v_host_sgl({seqstart_k_host.back(), nhead_k, hdim_v}); ck_tile::HostTensor vnew_host( 0 < seqlen_knew ? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v) @@ -742,13 +755,21 @@ bool run(const ck_tile::ArgParser& arg_parser) } } } + k_host.ForEach([&](auto& self, auto i) { + k_host_sgl(page_idx_host(i[1]), i[2], i[3]) = self(i); + }); + v_host.ForEach([&](auto& self, auto i) { + v_host_sgl(page_idx_host(i[1]), i[2], i[3]) = self(i); + }); + // k_host.savetxt("k_host.txt"); + // k_host_sgl.savetxt("k_host_sgl.txt"); + // v_host_sgl.savetxt("v_host_sgl.txt"); iota_shuffle(block_table_host.begin(), block_table_host.end(), 0); iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0); - ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host_sgl.get_element_space_size_in_bytes()); ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host_sgl.get_element_space_size_in_bytes()); ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); @@ -757,6 +778,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem page_idx(page_idx_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) @@ -773,14 +795,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes()); q_buf.ToDevice(q_host.data()); - k_buf.ToDevice(k_host.data()); + k_buf.ToDevice(k_host_sgl.data()); knew_buf.ToDevice(knew_host.data()); - v_buf.ToDevice(v_host.data()); + v_buf.ToDevice(v_host_sgl.data()); vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() : seqstart_k_with_padding_host.data()); + page_idx.ToDevice(page_idx_host.data()); + seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr); @@ -996,6 +1020,8 @@ bool run(const ck_tile::ArgParser& arg_parser) (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); args.seqstart_k_ptr = (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); + args.page_idx_ptr = + (mode == mode_enum::group ? page_idx.GetDeviceBuffer() : nullptr); args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr); @@ -1171,7 +1197,6 @@ bool run(const ck_tile::ArgParser& arg_parser) auto o_naive_ref = o_naive_buf.ToHost(); o_buf.FromDevice(o_host.data()); // TODO: ugly - auto [rtol_, atol_] = get_elimit(init_method); bool pass_ = ck_tile::check_err( o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_); @@ -1513,6 +1538,8 @@ bool run(const ck_tile::ArgParser& arg_parser) else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); // clang-format on + // o_host_result.savetxt("o_host_result.txt"); + // o_host_ref.savetxt("o_host_ref.txt"); auto [rtol, atol] = get_elimit(init_method); bool cur_pass = ck_tile::check_err( o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 765c221a7b..1f02e9f729 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -129,7 +129,8 @@ struct fmha_fwd_args const void* seqstart_k_ptr; const void* seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr - + const void* page_idx_ptr; + ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; @@ -326,6 +327,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.seqstart_q_ptr, args.seqstart_k_ptr, args.seqlen_k_ptr, + args.page_idx_ptr, args.hdim_q, args.hdim_v, args.nhead_q, diff --git a/example/ck_tile/XX_moe_gemm/run_moe_gemm_example.inc b/example/ck_tile/XX_moe_gemm/run_moe_gemm_example.inc new file mode 100644 index 0000000000..5b1e838fd1 --- /dev/null +++ b/example/ck_tile/XX_moe_gemm/run_moe_gemm_example.inc @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +using F16 = ck_tile::half_t; +// using BF16 = ck::bhalf_t; +using F8 = ck_tile::fp8_t; +using F32 = float; + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_gemm(int n_warmup, int n_repeat, const moe_gemm_kargs& args) +{ + float ave_time = moe_gemm( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::string op_name{"Moe Gemm"}; + + std::size_t flop = 0, num_btype = 0; + + flop += std::size_t(2) * args.M * args.N * args.K; + + num_btype += sizeof(ADataType) * args.M * args.K + sizeof(BDataType) * args.K * args.N + + sizeof(CDataType) * args.M * args.N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + return ave_time; +} + +template +int run_moe_gemm_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + + if(!result) + { + return -1; + }; + + // auto valid_input_data = [&](int group_count, const auto&... args) { + // return !(args.empty() || ...) && group_count == (args.size() == ...); + // }; + + // ck_tile::index_t N = 4096; + // ck_tile::index_t K = 4096; + // ck_tile::index_t experts = 8; + // ck_tile::index_t sorted_tile_num = 8; + // ck_tile::index_t valid_tile_num = 8; + // ck_tile::index_t tokens = 128; + // ck_tile::index_t topk = 2; + + const ck_tile::index_t N = arg_parser.get_int("N"); + const ck_tile::index_t K = arg_parser.get_int("K"); + ck_tile::index_t stride_A = arg_parser.get_int("stride_A"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_B"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_C"); + const ck_tile::index_t num_tokens = arg_parser.get_int("NumTokens"); + const ck_tile::index_t topk = arg_parser.get_int("TopK"); + const ck_tile::index_t repeat = arg_parser.get_int("repeat"); + const ck_tile::index_t experts = arg_parser.get_int("experts"); + + // TODO: replace the magic declaration + const ck_tile::index_t MPerBlock = 128; + ck_tile::index_t sorted_tile_num = 8; + ck_tile::index_t valid_tile_num = sorted_tile_num; + + const ck_tile::index_t M = sorted_tile_num * MPerBlock; + + std::unique_ptr a_m_k_dev_buf; + std::unique_ptr b_k_n_dev_buf; + std::unique_ptr c_m_n_dev_buf; + + stride_A = ck_tile::get_default_stride(M, N, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + auto a_m_k_tensor = ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + + // TODO: add the experts' weights in b + auto b_k_n_tensor = ck_tile::HostTensor( + is_row_major(b_layout) + ? ck_tile::host_tensor_descriptor(experts * K, N, stride_B, is_row_major(b_layout)) + : ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout))); + auto c_m_n_tensor = ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + std::cout << "gemm" + << " a_m_k: " << a_m_k_tensor.mDesc << " b_k_n: " << b_k_n_tensor.mDesc + << " c_m_n: " << c_m_n_tensor.mDesc << std::endl; + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensor); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensor); + + a_m_k_dev_buf = + std::make_unique(a_m_k_tensor.get_element_space_size_in_bytes()); + b_k_n_dev_buf = + std::make_unique(b_k_n_tensor.get_element_space_size_in_bytes()); + c_m_n_dev_buf = + std::make_unique(c_m_n_tensor.get_element_space_size_in_bytes()); + + a_m_k_dev_buf->ToDevice(a_m_k_tensor.data()); + b_k_n_dev_buf->ToDevice(b_k_n_tensor.data()); + c_m_n_dev_buf->SetZero(); + c_m_n_tensor.SetZero(); + + const void* p_a = a_m_k_dev_buf->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf->GetDeviceBuffer(); + void* p_c = c_m_n_dev_buf->GetDeviceBuffer(); + + // TODO: malloc and init sorted tokens and max tokens buffer + + ck_tile::HostTensor expert_ids( + ck_tile::HostTensorDescriptor({sorted_tile_num}, {1})); + ck_tile::HostTensor sorted_token_ids( + ck_tile::HostTensorDescriptor({sorted_tile_num * MPerBlock}, {1})); + ck_tile::HostTensor max_token_id( + ck_tile::HostTensorDescriptor({1 + sorted_tile_num})); + + std::unique_ptr sorted_token_ids_dev = std::make_unique( + sizeof(ck_tile::index_t) * sorted_token_ids.get_element_space_size_in_bytes()); + std::unique_ptr expert_ids_dev = std::make_unique( + sizeof(ck_tile::index_t) * expert_ids.get_element_space_size_in_bytes()); + std::unique_ptr max_token_id_dev = std::make_unique( + sizeof(ck_tile::index_t) * max_token_id.get_element_space_size_in_bytes()); + + max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8}; + int eids[] = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = eids[i]; + } + int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + // sorted_token_ids.mData[0] = 0; + for(int i = 0; i < sorted_tile_num * MPerBlock; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile && tokenid < num_tokens * topk) + { + sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = num_tokens - 1; + } + } + + sorted_token_ids_dev->ToDevice(sorted_token_ids.data()); + expert_ids_dev->ToDevice(expert_ids.data()); + max_token_id_dev->ToDevice(max_token_id.data()); + + const ck_tile::index_t* p_sorted_token_ids_dev = + static_cast(sorted_token_ids_dev->GetDeviceBuffer()); + const ck_tile::index_t* p_expert_ids_dev = + static_cast(expert_ids_dev->GetDeviceBuffer()); + const ck_tile::index_t* p_max_token_id_dev = + static_cast(max_token_id_dev->GetDeviceBuffer()); + + moe_gemm_kargs gemm_desc{p_sorted_token_ids_dev, + p_expert_ids_dev, + p_max_token_id_dev, + p_a, + p_b, + p_c, + num_tokens, + topk, + M, + N, + K, + stride_A, + stride_B, + stride_C}; + + invoke_gemm(3, repeat, gemm_desc); + + c_m_n_dev_buf->FromDevice(c_m_n_tensor.data()); + + bool pass{true}; + if(arg_parser.get_int("validate")) + { + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + c_m_n_host_ref.SetZero(); + + std::unique_ptr c_m_n_ref_buf = + std::make_unique(c_m_n_tensor.get_element_space_size_in_bytes()); + + c_m_n_ref_buf->SetZero(); + + ck_tile::reference_moe_gemm_gpu( + p_sorted_token_ids_dev, + p_expert_ids_dev, + p_max_token_id_dev, + static_cast(p_a), + static_cast(p_b), + static_cast(c_m_n_ref_buf->GetDeviceBuffer()), + num_tokens, + topk, + M, + N, + K, + stride_A, + stride_B, + stride_C); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, 1 /*kbatch*/, max_accumulated_value); + c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data()); + + // for(int im = 0; im < M; im++) + // { + // for(int in = 0; in < N; in++) + // { + // // if (static_cast(static_cast(p_c)[im * N + in]) != 0) + // printf("c[%d][%d]: %f ", + // im, + // in, + // static_cast(static_cast(p_c)[im * N + in])); + // printf("ref[%d][%d]: %f \n", + // im, + // in, + // static_cast( + // static_cast(c_m_n_host_ref.data())[im * N + in])); + // } + // } + + pass = ck_tile::check_err(c_m_n_tensor, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + } + + return pass; +} + +int run_moe_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if(a_layout == "R" && b_layout == "C") + { + return run_moe_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + // else if(a_layout == "R" && b_layout == "R") + // { + // return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + // } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index b280a1725d..173f23441b 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -18,32 +18,10 @@ namespace ck_tile { -template -CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution& tile_window, - number = {}, - bool_constant = {}) -{ - return tile_window.load(number{}, bool_constant{}); -} - -template -CK_TILE_DEVICE auto load_tile(const tile_window_linear& tile_window, +CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, number = {}, bool_constant = {}) { @@ -51,35 +29,11 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, - const tile_window_with_static_distribution& tile_window, - number = {}, - bool_constant = {}) -{ - return tile_window.load(dst_tile, number{}, bool_constant{}); -} - -template -CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, - const tile_window_linear& tile_window, + const TileWindow_& tile_window, number = {}, bool_constant = {}) { diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp index d5a716664d..fa7eb7b089 100644 --- a/include/ck_tile/core/tensor/store_tile.hpp +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -104,6 +104,31 @@ CK_TILE_DEVICE void store_tile( tile_window.store(dstr_tensor, number<-1>{}); } +// template +// CK_TILE_DEVICE void +// store_tile(tile_window_with_static_lengths& tile_window_tmp, +// const static_distributed_tensor& dstr_tensor, +// const T& offsets) +// { +// using DataType = remove_cvref_t; +// using TileDstr = remove_cvref_t; + +// static_assert(std::is_same_v, DataType>, "wrong!"); + +// constexpr auto tile_dstr = TileDstr{}; + +// auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), +// tile_window_tmp.get_window_lengths(), +// tile_window_tmp.get_window_origin(), +// tile_dstr); + +// tile_window.store(dstr_tensor, offsets); +// } + template +struct tile_scatter_gather +{ + using BottomTensorView = remove_reference_t; + using WindowLengths = remove_cvref_t; + using TileDstr = remove_cvref_t; + using PageIdxArray = remove_cvref_t; + using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; + using BottomTensorDesc = typename BottomTensorView::TensorDesc; + + using DataType = remove_cvref_t; + + static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension(); + static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); + + static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p(); + static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y(); + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static_assert(NumCoord == 1); + + // TODO: check WindowLengths and StaticTileDistribution are consistent + + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! lengths should be static"); + static_assert(TileDstr::is_static(), "wrong!"); + + static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(), + "wrong! inconsistent # of diemsnions"); + + using AdaptorTopIndex = array; + using BottomTensorIndex = array; + + using WindowAdaptorCoord = + decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})); + + using BottomTensorCoord = + decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})); + + struct load_store_traits + { + private: + static constexpr auto get_vector_dim_y_scalar_per_vector() + { + const auto [ys_vector_lengths, ys_vector_strides] = + tile_scatter_gather:: + get_window_adaptor_ys_safe_vector_length_strides(); + + index_t VectorDimY_ = 0; + index_t ScalarPerVector_ = 1; + + for(index_t i = 0; i < NDimY; ++i) + { + if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_) + { + ScalarPerVector_ = ys_vector_lengths[i]; + VectorDimY_ = i; + } + } + + return make_tuple(VectorDimY_, ScalarPerVector_); + } + + public: + static constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); + static constexpr index_t ScalarPerVector = + get_vector_dim_y_scalar_per_vector().template at<1>(); + + // using vector_type_t = vector_type_maker_t; + // using vector_t = typename vector_type_t::type; + using vector_t = thread_buffer; + + private: + static constexpr auto scalars_per_access_ = [] { + constexpr auto scalars_per_access_arr = generate_array( + [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number{}); + + /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE() + constexpr auto NDimY_ = NDimY; + + return TO_SEQUENCE(scalars_per_access_arr, NDimY_); + }(); + + static constexpr auto get_space_filling_curve() + { + constexpr auto tile_dstr = TileDstr{}; + + constexpr auto thread_tensor_lengths_ys = + to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths()); + + // FIXME: need logic to judge dim access order + using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type; + + return space_filling_curve{}; + } + + public: + using SFC_Ys = decltype(get_space_filling_curve()); + + static constexpr index_t NumAccess = SFC_Ys::get_num_of_access(); + + static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0"); + static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord"); + }; + + static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord; + + CK_TILE_DEVICE constexpr tile_scatter_gather() = default; + + CK_TILE_DEVICE constexpr tile_scatter_gather( + const BottomTensorView& bottom_tensor_view, + const WindowLengths& window_lengths, + const BottomTensorIndex& window_origin, + const TileDstr& tile_distribution, + const PageIdxArray& page_idx) + : bottom_tensor_view_{bottom_tensor_view}, + window_lengths_{window_lengths}, + window_origin_{window_origin}, + tile_dstr_{tile_distribution}, + page_idx_{page_idx}, + pre_computed_coords_{} + { +#if 0 // debug + // TODO: this use more register for FA, but less register for GEMM + // need investigation + // only support warp-tile and block-tile + static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + + WindowAdaptorCoord window_adaptor_thread_coord_tmp; + + if constexpr(NDimP == 1) + { + window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); + } + else if constexpr(NDimP == 2) + { + window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(), + AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); + } +#else + // TODO: this use less register for FA, but more register for GEMM + // need investigation + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_distribution.get_ps_ys_to_xs_adaptor(), + container_concat(detail::get_partition_index(tile_distribution), + array{0})); +#endif + + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); + bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0; + // BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + // tuple(0, window_adaptor_thread_coord_tmp.get_bottom_index()[1]); + const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( + bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up + // future load/store() calls (might allocate more registers) + using Traits = load_store_traits; + using SFC_Ys = typename Traits::SFC_Ys; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp; + auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp; + + constexpr auto idx_diff_ys = + SFC_Ys::get_step_between(number<0>{}, number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + pre_computed_coords_(iCoord) = + make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + }); + } + + CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } + + CK_TILE_DEVICE static constexpr bool has_static_tile_distribution() + { + return TileDstr::is_static(); + } + + CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } + + CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; } + + CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } + + CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } + + CK_TILE_DEVICE constexpr void + set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data) + { + bottom_tensor_view_.buf_.p_data_ = data; + } + + // move thread's window adaptor coordinate and bottom tensor coordinate + // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] + template + CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( + WindowAdaptorCoord& window_adaptor_thread_coord, + BottomTensorCoord& bottom_tensor_thread_coord, + const ATopIndex& idx_diff_adaptor_top) const + { + array idx_diff_adaptor_bottom; + + move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), + window_adaptor_thread_coord, + idx_diff_adaptor_top, + idx_diff_adaptor_bottom); + + move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + bottom_tensor_thread_coord, + idx_diff_adaptor_bottom); + } + + // return vector dimension among [y0, y1, ...] + CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides() + { + // bottom tensor top dimension vector lengths and strides + const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] = + BottomTensorDesc::get_top_dimension_safe_vector_length_strides(); + + // window vector lengths/strides + const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths; + const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides; + + // window adaptor [p0, p1, ..., y0, y1, ...] + array window_adaptor_vector_lengths{ + -1}; + array window_adaptor_vector_strides{ + -1}; + + constexpr auto window_adaptor_bottom_dims = + WindowAdaptor::get_bottom_dimension_hidden_ids(); + + set_container_subset(window_adaptor_vector_lengths, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_lengths); + set_container_subset(window_adaptor_vector_strides, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_strides); + + const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] = + WindowAdaptor{}.get_top_dimension_safe_vector_length_strides( + window_adaptor_vector_lengths, window_adaptor_vector_strides); + + // [y0, y1, ...] + constexpr auto y_dims = typename arithmetic_sequence_gen::type{}; + + return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims), + get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); + } + + CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; } + + template + CK_TILE_DEVICE auto load(number = {}, + bool_constant = {}) const + { + constexpr auto tile_dstr = TileDstr{}; + auto dst_tensor = make_static_distributed_tensor(tile_dstr); + load(dst_tensor, number{}, bool_constant{}); + return dst_tensor; + } + + template + CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, + number = {}, + bool_constant = {}) const + { + using Traits = load_store_traits; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr auto idx_m = idx_ys_start[number{}]; + const auto page_offset = page_idx_[idx_m]; + + // read from bottom tensor + const vector_t vec_value = + get_bottom_tensor_view().template get_vectorized_elements( + bottom_tensor_thread_coord, page_offset, bool_constant{}); +#if 1 + // write into distributed tensor + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; + + dst_tensor.get_thread_buffer().template at() = + vec_value.template get_as()[j / Traits::PackedSize]; + }); +#else + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); + static_assert(d % Traits::ScalarPerVector == 0); + + dst_tensor.get_thread_buffer().template get_as()( + number{}) = bit_cast(vec_value); +#endif + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto forward_step_scatter = generate_tuple( + [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + forward_step_scatter); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + + template + CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, + number = {}, + bool_constant = {}) const + { + using Traits = load_store_traits; + + // using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + // printf("off %d\n", page_idx_[I0]); + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + // BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + // window_origin_ + + // tuple(0, window_adaptor_thread_coord.get_bottom_index()[1]); + + // auto bottom_tensor_thread_coord = make_tensor_coordinate( + // bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr auto idx_m = idx_ys_start[number<0>{}]; + const auto page_offset = page_idx_[idx_m]; + + // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n", + // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0); + + // read from distributed tensor + // vector_type_t vec; + vector_t vec_value; + + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; + // printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j); + vec_value.template get_as()(j / Traits::PackedSize) = + dstr_tensor.get_thread_buffer().template at(); + }); + + // const vector_t vec_value = vec.template get_as().template at<0>(); + + // write into bottom tensor + get_bottom_tensor_view().template set_vectorized_elements( + bottom_tensor_thread_coord, + page_offset, + vec_value, + bool_constant{}); + // printf("coord_offset:%d, scatter_offset:%d \n", + // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto forward_step_scatter = generate_tuple( + [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + forward_step_scatter); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + + // move thread's botom tensor coordiante + // [x0', x1', ... ] ==> [offset] + // also move window-origin + CK_TILE_DEVICE void move(const BottomTensorIndex& step) + { + window_origin_ += step; + BottomTensorIndex step_new = step; + step_new(HsGatherDim) = 0; + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + pre_computed_coords_(iCoord)(I1), + step_new); + }); + } + + CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) + { + page_idx_ = new_idx; + + // static_for<0, 2, 1>{}([&](auto k0) { + // printf("update tid %d %d \n", threadIdx.x, page_idx_[k0]); + // }); + } +// CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) +// { +// window_origin_ = new_window_origin; + +// #if 0 // debug +// // TODO: this use more register for FA, but less register for GEMM +// // need investigation +// // only support warp-tile and block-tile +// static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + +// WindowAdaptorCoord window_adaptor_thread_coord_tmp; + +// if constexpr(NDimP == 1) +// { +// window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( +// tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); +// } +// else if constexpr(NDimP == 2) +// { +// window_adaptor_thread_coord_tmp = +// make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), +// AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); +// } +// #else +// // TODO: this use less register for FA, but more register for GEMM +// // need investigation +// const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( +// tile_dstr_.get_ps_ys_to_xs_adaptor(), +// container_concat(detail::get_partition_index(tile_dstr_), array{0})); +// #endif + +// BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = +// window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); + +// const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( +// bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + +// // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up +// // future load/store() calls (might allocate more registers) +// using Traits = load_store_traits; +// using SFC_Ys = typename Traits::SFC_Ys; + +// static_for<0, NumCoord, 1>{}([&](auto iCoord) { +// auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp; +// auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp; + +// constexpr auto idx_diff_ys = +// SFC_Ys::get_step_between(number<0>{}, number{}); + +// constexpr auto idx_diff_ps_ys = container_concat( +// generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); + +// move_window_adaptor_and_bottom_tensor_thread_coordinate( +// window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + +// pre_computed_coords_(iCoord) = +// make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); +// }); +// } + + CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); } + + // this is the bottom tensor view + // [x0', x1', ...] ==> [offset] + BottomTensorView bottom_tensor_view_; + + // + WindowLengths window_lengths_; + + // origin ([x0', x1', ...]) of window on bottom tensor + BottomTensorIndex window_origin_; + + // Tile tensor distribution, which contains: + // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] + // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d] + TileDstr tile_dstr_; + + PageIdxArray page_idx_; + + // this contains: + // per-thread coordinate for window adaptor + // per-thread coordinate for bottom tensor + array, NumCoord> pre_computed_coords_; +}; + +// TODO: use strategy +template +CK_TILE_DEVICE constexpr auto +make_tile_scatter_gather(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + const StaticPageIndexArray_& page_idx, + number = {}, + number = {}) +{ + return tile_scatter_gather, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + HsGatherDim, + NumCoord>{ + tensor_view, window_lengths, origin, tile_distribution, page_idx}; +} + +template +CK_TILE_DEVICE constexpr auto +make_tile_scatter_gather(const tile_window_with_static_lengths& tile_window, + const multi_index& origin, + const StaticTileDistribution& tile_distribution, + const StaticPageIndexArray& page_idx, + number = {}) +{ + return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + origin, + tile_distribution, + page_idx, + number{}); +} + +template +CK_TILE_DEVICE constexpr auto +make_tile_scatter_gather(const tile_window_with_static_lengths& tile_window, + const StaticTileDistribution& tile_distribution, const StaticPageIndexArray& page_idx, + number = {}) +{ + return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + tile_window.get_window_origin(), + tile_distribution, + page_idx, + number{}); +} + +// template +// CK_TILE_DEVICE constexpr auto +// make_tile_window_raw(const tile_window_with_static_lengths& tile_window, +// const StaticTileDistribution& tile_distribution) +// { +// auto w = make_tile_scatter_gather(tile_window.get_bottom_tensor_view(), +// tile_window.get_window_lengths(), +// tile_window.get_window_origin(), +// tile_distribution); +// w.init_raw(); +// return w; +// } + + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 3bb728df23..2a3c9fdc64 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -609,6 +609,93 @@ struct tile_window_with_static_distribution }); } + // template + // CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, + // const statically_indexed_array& offsets, + // number = {}, + // bool_constant = {}) const + // { + // using Traits = load_store_traits; + + // // using vector_type_t = typename Traits::vector_type_t; + // using vector_t = typename Traits::vector_t; + // using SFC_Ys = typename Traits::SFC_Ys; + + // constexpr auto tile_dstr = TileDstr{}; + + // // loop over thread tensor space [y0, y1, ...] + // static_for<0, NumCoord, 1>{}([&](auto iCoord) { + // auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + // // auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + // BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + // window_origin_ + + // tuple(0, window_adaptor_thread_coord.get_bottom_index()[1]); + + // auto bottom_tensor_thread_coord = make_tensor_coordinate( + // bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + // static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + // constexpr auto iAccess = number{}; + + // // data index [y0, y1, ...] + // constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + // constexpr auto idx_m = idx_ys_start[number<0>{}]; + // const auto offset = offsets[idx_m]; + + // // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n", + // // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0); + + // // read from distributed tensor + // // vector_type_t vec; + // vector_t vec_value; + + // static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { + // constexpr auto idx_ys = generate_tuple( + // [&](auto jj) { + // return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + // : idx_ys_start[jj]; + // }, + // number{}); + + // constexpr index_t d = + // tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + // Traits::PackedSize; + // // printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j); + // vec_value.template get_as()(j / Traits::PackedSize) = + // dstr_tensor.get_thread_buffer().template at(); + // }); + + // // const vector_t vec_value = vec.template get_as().template at<0>(); + + // // write into bottom tensor + // get_bottom_tensor_view().template set_vectorized_elements( + // bottom_tensor_thread_coord, + // offset, + // vec_value, + // bool_constant{}); + // // printf("coord_offset:%d, scatter_offset:%d \n", + // // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate + // if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + // { + // constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + // constexpr auto forward_step_scatter = generate_tuple( + // [&](auto i) { return i == 0 ? 0 : idx_diff_ys[i]; }, number{}); + + // constexpr auto idx_diff_ps_ys = container_concat( + // generate_tuple([&](auto) { return number<0>{}; }, number{}), + // forward_step_scatter); + + // move_window_adaptor_and_bottom_tensor_thread_coordinate( + // window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + // } + // }); + // }); + // } + template CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, number = {}, @@ -1010,23 +1097,6 @@ make_tile_window_raw(const TensorView_& tensor_view, return w; } -template -CK_TILE_DEVICE void move_tile_window( - tile_window_with_static_distribution& window, - const typename tile_window_with_static_distribution::BottomTensorIndex& step) -{ - window.move(step); -} - /** * @brief This class provides description of tile windowed view on the device memory. * @@ -1155,13 +1225,4 @@ make_tile_window_raw(const tile_window_with_static_lengths -CK_TILE_DEVICE void move_tile_window( - tile_window_with_static_lengths& window, - const typename tile_window_with_static_lengths::BottomTensorIndex& - step) -{ - window.move(step); -} - } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 1e24e660f6..6af6813e0c 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -1200,19 +1200,4 @@ make_tile_window_linear_raw(const TileWindow_& tile_window, LinearBottomDims_{}); } -template -CK_TILE_DEVICE void move_tile_window( - tile_window_linear& - window, - const typename tile_window_linear::BottomTensorIndex& step) -{ - window.move(step); -} - } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window_utils.hpp b/include/ck_tile/core/tensor/tile_window_utils.hpp index 71a72329f8..a6d4fcde36 100644 --- a/include/ck_tile/core/tensor/tile_window_utils.hpp +++ b/include/ck_tile/core/tensor/tile_window_utils.hpp @@ -18,6 +18,14 @@ #pragma once namespace ck_tile { +template +CK_TILE_DEVICE void move_tile_window( + TileWindow_& window, + const typename TileWindow_::BottomTensorIndex& step) +{ + window.move(step); +} + // input a lds store tile, extract some information from it // used to set m0 value for gfx9 serious template diff --git a/include/ck_tile/host/reference/reference_fused_single_moe_gemm.hpp b/include/ck_tile/host/reference/reference_fused_single_moe_gemm.hpp new file mode 100644 index 0000000000..d98a1e899a --- /dev/null +++ b/include/ck_tile/host/reference/reference_fused_single_moe_gemm.hpp @@ -0,0 +1,236 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +// template +// CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, +// const HostTensor& b_k_n, +// HostTensor& c_m_n, +// const AElementOp& a_element_op = {}, +// const BElementOp& b_element_op = {}, +// const ACCElementOp& acc_element_op = {}) +// { +// const std::size_t M = a_m_k.get_length(0); +// const std::size_t N = b_k_n.get_length(1); +// const std::size_t K = a_m_k.get_length(1); + +// auto f_mn = [&](auto m, auto n) { +// AccDataType v_acc = 0; + +// for(std::size_t k = 0; k < K; ++k) +// { +// AccDataType v_a; +// AccDataType v_b; +// if constexpr(std::is_same_v) +// { +// const pk_int4_t pk_val = a_element_op(a_m_k(m, k)); +// const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); +// if(k % 2 == 1) +// v_a = fp32_val.hi; +// else +// v_a = fp32_val.lo; +// } +// else +// { +// v_a = ck_tile::type_convert(a_element_op(a_m_k(m, k))); +// } +// if constexpr(std::is_same_v) +// { +// const pk_int4_t pk_val = b_element_op(b_k_n(k, n)); +// const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); +// if(k % 2 == 1) +// v_b = fp32_val.hi; +// else +// v_b = fp32_val.lo; +// } +// else +// { +// v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n))); +// } +// v_acc += v_a * v_b; +// } + +// c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); +// }; + +// make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); +// } + +template +__global__ void naive_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_, + const ck_tile::index_t* p_sorted_expert_ids_, + const ck_tile::index_t* p_max_token_id_, + const ADataType* A, + const BDataType* B, + CDataType* C, + ck_tile::index_t Num_tokens, + ck_tile::index_t TopK, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t strideA, + ck_tile::index_t strideB, + ck_tile::index_t strideC) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int row = idx / N; // Compute row index + int col = idx % N; // Compute column index + (void)Num_tokens; + // assert(p_sorted_expert_ids_ != nullptr); + // assert(TopK == 1); + // assert(Num_tokens == 128); + // if(Num_tokens == 128 && TopK == 1 && p_sorted_expert_ids_ != nullptr) {} + + // index_t max_tokens = p_max_token_id_[0]; + index_t gather_token_id = 0; + index_t scatter_token_id = 0; + index_t expert_id = 0; + + if(row < p_max_token_id_[0]) + { + expert_id = p_sorted_expert_ids_[row / 128]; + gather_token_id = p_sorted_token_ids_[row] & 0xffffff; + scatter_token_id = p_sorted_token_ids_[row] & 0xffffff; + if(!IsInputGemm) + { + gather_token_id = gather_token_id * TopK + (p_sorted_token_ids_[row] >> 24); + } + else + { + scatter_token_id = scatter_token_id * TopK + (p_sorted_token_ids_[row] >> 24); + } + } + else + { + return; + } + + if(row < M && col < N) + { + AccDataType acc = 0.0; + for(int k = 0; k < K; ++k) + { + constexpr index_t packed_size_a = ck_tile::numeric_traits::PackedSize; + constexpr index_t packed_size_b = ck_tile::numeric_traits::PackedSize; + // Adjust indexing based on matrix layout + int a_index = (std::is_same_v) + ? gather_token_id * strideA + k + : k * strideA + gather_token_id; + + // TODO: add experts weights dispatch + int b_index = + expert_id * N * K + ((std::is_same_v) + ? col * strideB + k + : k * strideB + col); + + AccDataType v_a; + AccDataType v_b; + if constexpr(std::is_same_v) + { + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]); + if(k % 2 == 1) + v_a = fp32_val.hi; + else + v_a = fp32_val.lo; + } + else + { + v_a = ck_tile::type_convert(A[a_index]); + } + if constexpr(std::is_same_v) + { + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]); + if(k % 2 == 1) + v_b = fp32_val.hi; + else + v_b = fp32_val.lo; + } + else + { + v_b = ck_tile::type_convert(B[b_index]); + } + acc += v_a * v_b; + } + + int c_index = (std::is_same_v) + ? scatter_token_id * strideC + col + : col * strideC + scatter_token_id; + C[c_index] = ck_tile::type_convert(acc); + } +} + +template +void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_, + const index_t* p_sorted_expert_ids_, + const index_t* p_max_token_id_, + const ADataType* a_ptr, + const BDataType* b_ptr, + CDataType* c_ptr, + index_t Num_tokens, + index_t TopK, + index_t M, + index_t N, + index_t K, + index_t stride_a, + index_t stride_b, + index_t stride_c) +{ + int totalElements = M * N; + int numThreadsPerBlock = 256; // Common choice for threads per block + int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; + + naive_gemm_kernel<<>>(p_sorted_token_ids_, + p_sorted_expert_ids_, + p_max_token_id_, + a_ptr, + b_ptr, + c_ptr, + Num_tokens, + TopK, + M, + N, + K, + stride_a, + stride_b, + stride_c); + + return; +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 0081edcb2e..575bc55d1b 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -6,7 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" - +#include "ck_tile/core/tensor/tile_scatter_gather.hpp" namespace ck_tile { template + CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, + const OAccTile& o_acc_tile, + void* p_smem, + const index_t* p_sorted_tokens_id, + index_t token_pos) + { + + const index_t iMWarp = get_warp_id() / kNWave; + const index_t iNWarp = get_warp_id() - iMWarp * kNWave; + + constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); + auto o_lds_block = make_tensor_view( + static_cast(p_smem), lds_block_desc); + auto in_lds_window = + make_tile_window(o_lds_block, + make_tuple(number{}, number{}), + {number{} * iMWarp, number{} * iNWarp}); + auto out_lds_window = + make_tile_window(o_lds_block, + make_tuple(number{}, number{}), + {0, 0}); + + using SFC = space_filling_curve, + sequence<0, 1>, + sequence>; + constexpr index_t num_access = SFC::get_num_of_access(); + + using TileEncodingPattern = + TileDistributionEncodingPattern2D; + constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); + // auto coord = dram_tile_distribution.calculate_index(); + + // const auto& view = out_dram_window.get_bottom_tensor_view(); + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + auto c_coord = dram_tile_distribution.calculate_index(); + + // printf("c_coord[0]:%d \n", c_coord[0]); + + CWarpTensor c_warp_in_tensor; + static_for<0, num_access, 1>{}([&](auto iAccess) { + constexpr auto idx_y_start = SFC::get_index(iAccess); + + // auto idx_m = number{})>{} + 0; + // printf("idx_y_start:%d \n", idx_m); + constexpr auto mIter = number{}) / (kMPerXdl * kMWave)>{}; + + statically_indexed_array offsets; + static_for<0, 2 /*CMrepeats*/, 1>{}([&](auto m0) { + auto token_id = token_pos + m0 + c_coord[0] + mIter * kMPerXdl * kMWave; + auto fused_token = p_sorted_tokens_id[token_id]; + + index_t token_offset = fused_token & 0xffffff; + + if constexpr(IsInputGemm) + { + token_offset = token_offset * 3 /*TopK*/ + (fused_token >> 24); + } + + offsets[m0] = token_offset * 4096; // Problem::kN_; + }); + // printf("c_coord[number<0>{}]: %d \n", coord[number<0>{}]); + // printf("mIter: %d", mIter+0); + + constexpr auto nIter = number{}) / (kNPerXdl * kNWave)>{}; + + // printf("mIter, nIter(%d, %d) \n", mIter+0, nIter+0); + + c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + const auto c_warp_in_tensor_casted = cast_tile(c_warp_in_tensor); + + block_sync_lds(); + store_tile(in_lds_window, c_warp_in_tensor_casted); + block_sync_lds(); + + const auto c_out_tensor = + load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); + + if constexpr(out_memory_data_op == memory_operation_enum::set) + { + + + auto tile_window = make_tile_scatter_gather(out_dram_window.get_bottom_tensor_view(), + out_dram_window.get_window_lengths(), + out_dram_window.get_window_origin(), + dram_tile_distribution, + offsets); + + tile_window.store(c_out_tensor); + // store_tile(out_dram_window, c_out_tensor, offsets); + } + else + { + update_tile(out_dram_window, c_out_tensor); + } + if constexpr(iAccess != num_access - 1) + { + constexpr auto step = SFC::get_forward_step(iAccess); + move_tile_window(out_dram_window, {0, step.at(number<1>{})}); + // printf("step.at(number<0>{}), step.at(number<1>{}):,%d, %d", + // step.at(number<0>{})+0, step.at(number<1>{})+0); + } + }); + } + template diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index c671463252..46a013f5db 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -265,6 +265,7 @@ struct FmhaFwdKernel const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; + const int32_t* page_idx; }; using Kargs = std::conditional_t; @@ -596,6 +597,7 @@ struct FmhaFwdKernel const void* seqstart_q_ptr, const void* seqstart_k_ptr, const void* seqlen_k_ptr, + const void* page_idx_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -654,7 +656,8 @@ struct FmhaFwdKernel {}, // placeholder for dropout reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), - reinterpret_cast(seqlen_k_ptr)}; + reinterpret_cast(seqlen_k_ptr), + reinterpret_cast(page_idx_ptr)}; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -720,6 +723,7 @@ struct FmhaFwdKernel const void* seqstart_q_ptr, const void* seqstart_k_ptr, const void* seqlen_k_ptr, + const void* page_idx_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -758,6 +762,7 @@ struct FmhaFwdKernel seqstart_q_ptr, seqstart_k_ptr, seqlen_k_ptr, + page_idx_ptr, hdim_q, hdim_v, num_head_q, @@ -799,6 +804,7 @@ struct FmhaFwdKernel const void* seqstart_q_ptr, const void* seqstart_k_ptr, const void* seqlen_k_ptr, + const void* page_idx_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -837,6 +843,7 @@ struct FmhaFwdKernel seqstart_q_ptr, seqstart_k_ptr, seqlen_k_ptr, + page_idx_ptr, hdim_q, hdim_v, num_head_q, @@ -958,7 +965,7 @@ struct FmhaFwdKernel const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; + // long_index_t batch_offset_k = 0; long_index_t batch_offset_v = 0; long_index_t batch_offset_bias = 0; long_index_t batch_offset_randval = 0; @@ -972,7 +979,7 @@ struct FmhaFwdKernel const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; + // batch_offset_k = key_start * kargs.stride_k; if constexpr(std::is_same_v) { batch_offset_v = key_start * kargs.stride_v; @@ -981,6 +988,8 @@ struct FmhaFwdKernel { batch_offset_v = key_start; } + + kargs.page_idx += key_start; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = query_start * kargs.stride_bias + key_start; @@ -1016,35 +1025,38 @@ struct FmhaFwdKernel kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; } } - else - { - batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; - } - if constexpr(kStoreLSE) - { - batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; - } - if constexpr(kHasDropout) - { - batch_offset_randval = - static_cast(i_batch) * kargs.batch_stride_randval; - } - batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; - } + // else + // { + // batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + // batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + // batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + // if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + // { + // batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + // } + // if constexpr(kStoreLSE) + // { + // batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + // } + // if constexpr(kHasDropout) + // { + // batch_offset_randval = + // static_cast(i_batch) * kargs.batch_stride_randval; + // } + // batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + // } // for simplicity, batch stride we just modify the pointer const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + static_cast(i_nhead) * kargs.nhead_stride_q + batch_offset_q; + // const KDataType* k_ptr = + // reinterpret_cast(kargs.k_ptr) + + // static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + // batch_offset_k; const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + - batch_offset_k; + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k; const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + @@ -1329,6 +1341,9 @@ struct FmhaFwdKernel position_encoding, kargs.scale_s, smem_ptr, + kargs.page_idx, + kargs.stride_k, + kargs.stride_v, dropout); } else @@ -1343,6 +1358,9 @@ struct FmhaFwdKernel position_encoding, kargs.scale_s, smem_ptr, + kargs.page_idx, + kargs.stride_k, + kargs.stride_v, dropout); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 8a4a925b81..75364f4138 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -8,6 +8,9 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/tensor/tile_scatter_gather.hpp" +// #include "ck_tile/core/tensor/tile_scatter_gather_debug.hpp" namespace ck_tile { @@ -45,6 +48,10 @@ struct BlockFmhaPipelineQRKSVS static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); static constexpr bool kIsGroupMode = Problem::kIsGroupMode; @@ -148,6 +155,9 @@ struct BlockFmhaPipelineQRKSVS PositionEncoding position_encoding, float scale_s, void* smem_ptr, + const index_t* page_idx, + const index_t stride_k, + const index_t stride_v, DropoutType& dropout) const { static_assert( @@ -241,11 +251,10 @@ struct BlockFmhaPipelineQRKSVS return o_acc; } } - auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}); + {seqlen_k_start, 0}); //todo fixme felix const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = @@ -257,11 +266,28 @@ struct BlockFmhaPipelineQRKSVS auto randval_dram_window = dropout.template MakeRandvalDramWindow( randval_dram_block_window_tmp, seqlen_k_start); + auto v_dist = Policy::template MakeVDramTileDistribution(); + auto v_coord = v_dist.calculate_index(); + const auto VPageIndexDim = I1; + using VDstrEncode = typename decltype(v_dist)::DstrEncode; + constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3]; + statically_indexed_array v_offsets; + static_for<0, V_KRepeat, 1>{}([&](auto k0) { + v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v; + // printf("1tid %d %d %d %d %d\n", threadIdx.x, v_coord[VPageIndexDim], k0.value, page_idx[v_coord[VPageIndexDim] + k0.value], stride_v); + }); auto v_dram_window = - make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), {0, seqlen_k_start}, // TODO: hdim split? - Policy::template MakeVDramTileDistribution()); + v_dist, + v_offsets, + VPageIndexDim); + // auto v_dram_window = + // make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + // v_dram_block_window_tmp.get_window_lengths(), + // {0, seqlen_k_start}, // TODO: hdim split? + // v_dist); auto q_tile = tile_elementwise_in(q_element_func, q); @@ -275,12 +301,20 @@ struct BlockFmhaPipelineQRKSVS do { // STAGE 1, QK gemm - auto k_dram_window = make_tile_window( + auto k_dist = Policy::template MakeKDramTileDistribution(); + auto k_coord = k_dist.calculate_index(); + using KDstrEncode = typename decltype(k_dist)::DstrEncode; + constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0]; + statically_indexed_array k_offsets; + static_for<0, NRepeat, 1>{}([&](auto n0) { + k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k; + }); + auto k_dram_window = make_tile_scatter_gather( k_dram_block_window.get_bottom_tensor_view(), k_dram_block_window.get_window_lengths(), k_dram_block_window.get_window_origin(), - Policy::template MakeKDramTileDistribution()); // K DRAM tile window for - // load + k_dist, + k_offsets); // K DRAM tile window for auto k_block_tile = load_tile(k_dram_window); { @@ -321,7 +355,13 @@ struct BlockFmhaPipelineQRKSVS }); } - const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + const auto v_prefetch = v_dram_window.load(); // prefetch load v tile + // const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + static_for<0, V_KRepeat, 1>{}([&](auto k0) { + v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v; + // printf("2tid %d %d %d %d\n", threadIdx.x, v_coord[VPageIndexDim], kK1 + v_coord[VPageIndexDim] + k0.value, page_idx[kK1 + v_coord[VPageIndexDim] + k0.value]); + }); + v_dram_window.update_page_idx(v_offsets); { // tail block_sync_lds(); gemm_0(s_acc, @@ -523,6 +563,12 @@ struct BlockFmhaPipelineQRKSVS { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { const auto v = load_tile(v_dram_window); // load next v + + static_for<0, V_KRepeat, 1>{}([&](auto k0) { + v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v; + // printf("3tid %d %d %d %d\n", threadIdx.x, v_coord[VPageIndexDim], kK1 * 2 + i_k1.value * kK1 + v_coord[VPageIndexDim] + k0.value, page_idx[kK1 + i_k1.value * kK1 + v_coord[VPageIndexDim] + k0.value]); + }); + v_dram_window.update_page_idx(v_offsets); block_sync_lds(); gemm_1(o_acc, get_slice_tile( @@ -556,6 +602,7 @@ struct BlockFmhaPipelineQRKSVS v_lds_window); block_sync_lds(); } + page_idx += kN0; } while(++i_total_loops < num_total_loop); // store lse @@ -626,6 +673,9 @@ struct BlockFmhaPipelineQRKSVS PositionEncoding position_encoding, float scale_s, void* smem_ptr, + const index_t* page_idx, + const index_t stride_k, + const index_t stride_v, DropoutType& dropout) const { return operator()(q_dram_block_window_tmp, @@ -646,6 +696,9 @@ struct BlockFmhaPipelineQRKSVS position_encoding, scale_s, smem_ptr, + page_idx, + stride_k, + stride_v, dropout); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 26f7e46f9f..05e85cba83 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -623,7 +623,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; constexpr index_t Banks = 32; // TODO: need change based on arch constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); - constexpr index_t kKPack = GetSmemKPackV(); + constexpr index_t kKPack = GetSmemKPackV(); // static_assert(PixelsPerRow % kKPack == 0); constexpr index_t NPerRow = PixelsPerRow / kKPack; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; @@ -783,19 +783,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy) { - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; // P + constexpr index_t N1 = GetAlignmentV(); // 8 + constexpr index_t N0 = kNPerBlock / N1; // P // 16 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; static_assert(total_pixels % N1 == 0); // TODO: this is not always true? - constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackV(); + constexpr index_t K3 = total_pixels / N1; //2 + constexpr index_t kKPack = GetSmemKPackV(); // 8 static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K2 = kKPack / K3; // //4 TODO: this dimention could be outside single wave if constexpr(get_warp_size() % (K2 * N0) == 0) { - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); + constexpr index_t K1 = get_warp_size() / (K2 * N0); // 2 + constexpr index_t K0 = kBlockSize / get_warp_size(); // 2 static_assert(kKPerBlock == K0 * K1 * K2 * K3); return make_static_tile_distribution( tile_distribution_encoding, diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 0e57af7aef..3d5fc84743 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -19,7 +19,6 @@ cmake -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_BUILD_TYPE=Release \ --D BUILD_DEV=ON \ -D GPU_TARGETS=$GPU_TARGETS \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \