diff --git a/Dockerfile.pytorch b/Dockerfile.pytorch index 2d3856fa2d..112197d207 100644 --- a/Dockerfile.pytorch +++ b/Dockerfile.pytorch @@ -22,6 +22,7 @@ RUN groupadd -g 109 render && \ chmod -R a+rwx /tmp/pytorch && \ sudo usermod -aG irc jenkins && \ #install hipblaslt + cd /tmp && \ git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && \ cd rocm-libraries && \ git checkout develop && \ @@ -29,4 +30,4 @@ RUN groupadd -g 109 render && \ git sparse-checkout set projects/hipblaslt shared/origami && \ cd projects/hipblaslt && \ git show --oneline -s && \ - CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx942;gfx950" -j 128 --skip_rocroller + CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --use-system-packages --architecture="gfx942;gfx950" -j 128 --skip_rocroller diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 35afb1181e..02d31d2324 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -10,7 +10,7 @@ if(NOT INST_TARGETS) endif() # validate user-specified fmha_fwd API list -set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill") +set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill;batch_prefill") set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") if(BUILD_TESTING) @@ -48,7 +48,6 @@ set(FMHA_FWD_CODE_GEN_COMMON_ARGS --targets ${FMHA_TARGETS_ARG} --api ${FMHA_FWD_APIS} --optdim 32,64,80,128,256 - # --filter fmha_fwd... ) set(FMHA_BWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py @@ -174,6 +173,13 @@ else() list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0) endif() +# conditionally enable call to the batch_prefill API in fmha_fwd example and tests +if("batch_prefill" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_BATCH_PREFILL_API=1) +else() + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_BATCH_PREFILL_API=0) +endif() + # conditionally specify the use of OCP_FP8 if(CK_USE_OCP_FP8) list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 35e8c1be49..7c3efb9c18 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -84,6 +84,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad}, {F_qscale}, {F_occupancy}, false, + {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; @@ -124,7 +125,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; #include @@ -201,9 +202,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v }} """ -FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{ - using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; + using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; return fmha_batch_prefill_(s, a); }} """ @@ -247,6 +248,7 @@ class FmhaFwdApiTrait: skpad: str dpad: str dvpad: str + sink: str # t/f constraint: CppConstraint kv_memory_layout: str kv_lookup_table: str @@ -343,6 +345,7 @@ class FmhaFwdPipeline: F_dropout: str # F_qscale: str # no/pertensor F_mask: str # value from MASK_MAP + F_sink: str # t/f (StreamLLM sink tokens) F_kv_memory_layout: str # F_kv_lookup_table: str # F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @@ -406,6 +409,11 @@ class FmhaFwdPipeline: else: n += "_nqscale" + if self.F_sink == "t": + n += "_sink" + else: + n += "_nsink" + n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table return n @@ -472,6 +480,7 @@ class FmhaFwdApiPool: trait.kv_lookup_table ], F_page_size=trait.page_size, + F_sink=BOOL_MAP[trait.sink], ) if_j = "if" if j == 0 else "else if" per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( @@ -578,6 +587,7 @@ class FmhaFwdKernel: F_mode=MODE_MAP[self.F_mode], F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag], F_page_size=self.F_page_size, + F_sink=BOOL_MAP[self.F_pipeline.F_sink], ) @property @@ -617,6 +627,7 @@ class FmhaFwdKernel: skpad=self.F_pipeline.F_skpad, dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, + sink=self.F_pipeline.F_sink, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, kv_memory_layout=self.F_pipeline.F_kv_memory_layout, kv_lookup_table=self.F_pipeline.F_kv_lookup_table, @@ -655,6 +666,7 @@ class KernelComponentFactory: bias, lse, dropout, + sink, kv_memory_layout, kv_lookup_table, ) in itertools.product( @@ -663,12 +675,13 @@ class KernelComponentFactory: BIAS_MAP.keys(), ["t", "f"], ["t", "f"], + ["t", "f"], SUPPORTED_KV_MEMORY_LAYOUT, SUPPORTED_KV_LOOKUP_TABLE, ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip elif dtype in ["fp8bf16"]: - # no need lse/dropout kernels + # no need lse/dropout/sink kernels for ( logits, qscale, @@ -684,7 +697,7 @@ class KernelComponentFactory: SUPPORTED_KV_MEMORY_LAYOUT, SUPPORTED_KV_LOOKUP_TABLE, ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", kv_memory_layout, kv_lookup_table)) # fmt: skip else: assert False return pipelines @@ -701,20 +714,34 @@ class CustomFactory(KernelComponentFactory): def get_fwd_blobs( - kernel_filter: Optional[str], receipt, optdim_list, mask_impl + kernel_filter: Optional[str], receipt, optdim_list, mask_impl, + targets: Optional[List[str]] = None ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing + # (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with + # non-gfx9 architectures (gfx11/gfx12/gfx10 are wave32 and use different + # buffer instruction formats). Skip all batch_prefill kernels for non-gfx9 targets. + has_non_gfx9 = targets is not None and any( + not t.startswith("gfx9") for t in targets + ) # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future gen = list() api_pool = FmhaFwdApiPool(mask_impl) + if has_non_gfx9: + return api_pool, gen + for dtype in FWD_DTYPE_MAP.keys(): d = CustomFactory.get_hdim_tile_size_dict(dtype) if d is None: continue # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): + # batch_prefill pipeline requires group mode (static_assert in pipeline problem) + if mode != "group": + continue for tile, pipeline in itertools.product( tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl) ): @@ -829,7 +856,7 @@ def write_blobs( optdim_list, mask_impl, ) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) @@ -844,7 +871,7 @@ def list_blobs( mask_impl, ) -> None: with file_path.open("a") as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets) for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 7d7d01bd05..6c842def58 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1452,6 +1452,7 @@ template + kHasSink_> { static constexpr auto kKVMemoryLayout = kKVMemoryLayout_; static constexpr auto kKVLookupTable = kKVLookupTable_; diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 40b8006381..e01555193f 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -387,7 +387,7 @@ fwd_result fmha_fwd_run(mode_enum mode, } #if(!(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || \ - CK_TILE_FMHA_FWD_PAGEDKV_API)) + CK_TILE_FMHA_FWD_PAGEDKV_API || CK_TILE_FMHA_FWD_BATCH_PREFILL_API)) if(0 < page_block_size) { std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option" @@ -395,7 +395,11 @@ fwd_result fmha_fwd_run(mode_enum mode, page_block_size = 0; } #endif - if(!(page_block_size % 128 == 0)) + // batch_prefill supports flexible page sizes (not just multiples of 128) + const bool need_128_aligned_page = + (CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || + CK_TILE_FMHA_FWD_PAGEDKV_API); + if(need_128_aligned_page && 0 < page_block_size && !(page_block_size % 128 == 0)) { std::cerr << "only paged-kvcache block size divisible by 128 are currently supported" << std::endl; @@ -972,9 +976,10 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem seqlen_q_buf(has_group_q_padding ? seqlen_qs.size() * sizeof(int32_t) : 0); // Buffers for key/value per-sequence logical (unpadded) lengths (used in batch mode with // kvcache or group mode with padding enabled) - ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || has_group_k_padding - ? seqlen_ks.size() * sizeof(int32_t) - : 0); + // batch_prefill (group+kvcache) also needs per-batch seqlen_k for VLLM_BLOCK_TABLE_2D + const bool need_seqlen_k_buf = (mode == mode_enum::batch && use_kvcache) || + has_group_k_padding || (mode == mode_enum::group && use_kvcache); + ck_tile::DeviceMem seqlen_k_buf(need_seqlen_k_buf ? seqlen_ks.size() * sizeof(int32_t) : 0); ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0 : cuq_cum.size() * sizeof(ck_tile::index_t)); ck_tile::DeviceMem cu_seqlen_kv_buf( @@ -1013,9 +1018,7 @@ fwd_result fmha_fwd_run(mode_enum mode, cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data()); cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data()); seqlen_q_buf.ToDevice(has_group_q_padding ? seqlen_qs.data() : nullptr); - seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || has_group_k_padding - ? seqlen_ks.data() - : nullptr); + seqlen_k_buf.ToDevice(need_seqlen_k_buf ? seqlen_ks.data() : nullptr); cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr); rotary_cos_buf.ToDevice(rotary_cos_host.data()); rotary_sin_buf.ToDevice(rotary_sin_host.data()); @@ -1146,6 +1149,17 @@ fwd_result fmha_fwd_run(mode_enum mode, { traits.use_pagedkv = (0 < page_block_size); } + else if constexpr(std::is_same_v>) + { + traits.has_dropout = (p_drop > 0.0f); + traits.qscale_type = qscale.type; + traits.kv_memory_layout = + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT; + traits.kv_lookup_table = + ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D; + traits.page_size = page_block_size; + } } }; @@ -1498,6 +1512,67 @@ fwd_result fmha_fwd_run(mode_enum mode, ? seqlen_k_buf.GetDeviceBuffer() : nullptr); } + else if constexpr(std::is_same_v>) + { + // Fields already set by the outer else block above: + // bias_ptr, lse_ptr, o_ptr, seqlen_k, max_seqlen_q, scale_s, + // logits_soft_cap, stride_bias/o, nhead/batch stride for bias/lse/o, + // window_size_left/right, sink_size, mask_type. + + // scale_p/scale_o: batch_prefill-specific fields absent from fmha_fwd_args. + args.scale_p = 1.f; + args.scale_o = 1.f; + + // Dropout fields: the outer fmha_fwd_args branch sets these; set them here + // for batch_prefill since it takes a separate inner branch. + args.rand_val_ptr = randval_buf.GetDeviceBuffer(); + args.stride_randval = stride_randval; + args.nhead_stride_randval = nhead_stride_randval; + args.batch_stride_randval = batch_stride_randval; + args.p_drop = p_drop; + args.s_randval = s_randval; + if(drop_prefs) + args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(), + drop_offset_buf.GetDeviceBuffer()); + else + args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); + + // Paged KV: LINEAR_LAYOUT + VLLM_BLOCK_TABLE_2D + // block_table_buf: [batch, max_blocks_per_seq] of physical page ids + // seqlen_k_buf: [batch] of per-batch seqlen_k values + args.num_total_pages = max_num_page_blocks; + args.page_block_size = page_block_size; + args.kv_memory_layout = + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT; + args.kv_lookup_table = + ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D; + args.kv_indptr = nullptr; + args.kv_page_indices = block_table_buf.GetDeviceBuffer(); + args.kv_last_page_lens = nullptr; + args.seqlen_k_ptr = seqlen_k_buf.GetDeviceBuffer(); + args.batch_stride_block_table = batch_stride_block_table; + + // group mode required: seqstart_q is prefix-sum of per-batch seqlen_q + args.seqstart_q_ptr = seqstart_q_buf.GetDeviceBuffer(); + + // batch_prefill LINEAR_LAYOUT strides for runner's K layout + // [max_num_page_blocks, nhead_k, page_block_size, hdim]: + // stride_k = hdim_q (token stride within one head's page slice) + // nhead_stride_k = page_block_size * hdim_q (head stride) + // batch_stride_k = nhead_k * page_block_size * hdim_q (page stride, already set) + args.stride_k = hdim_q; + args.nhead_stride_k = page_block_size * hdim_q; + // V is row-major, same layout convention + args.stride_v = hdim_v; + args.nhead_stride_v = page_block_size * hdim_v; + + // descale: not used for fp16/bf16 + args.q_descale_ptr = nullptr; + args.k_descale_ptr = nullptr; + args.v_descale_ptr = nullptr; + args.nblock_stride_kv_block_descale = 0; + args.nhead_stride_kv_block_descale = 0; + } } }; @@ -1524,6 +1599,21 @@ fwd_result fmha_fwd_run(mode_enum mode, } auto run_fwd = [&](const ck_tile::stream_config& sc) { +#if CK_TILE_FMHA_FWD_BATCH_PREFILL_API + // batch_prefill: group mode + paged KV, tested against the same CPU reference + if(1 == num_splits && use_kvcache && mode == mode_enum::group) + { + fmha_batch_prefill_traits bp_traits; + init_traits(bp_traits); + + fmha_batch_prefill_args bp_args; + init_args(bp_args); + + const float ave_time = fmha_batch_prefill(bp_traits, bp_args, sc); + if(ave_time >= 0.0f) + return ave_time; + } +#endif // CK_TILE_FMHA_FWD_BATCH_PREFILL_API #if CK_TILE_FMHA_FWD_PAGEDKV_API if(1 == num_splits && use_kvcache) { @@ -1844,7 +1934,8 @@ fwd_result fmha_fwd_run(mode_enum mode, q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); } #endif -#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API || \ + CK_TILE_FMHA_FWD_BATCH_PREFILL_API if(0 < page_block_size) { // clang-format off @@ -1895,7 +1986,8 @@ fwd_result fmha_fwd_run(mode_enum mode, }); } #endif -#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API || \ + CK_TILE_FMHA_FWD_BATCH_PREFILL_API if(0 < page_block_size) { if(is_v_rowmajor) diff --git a/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp index 914c988d09..6eece48831 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp @@ -118,14 +118,11 @@ template ) { - if(kargs.k_batch > 1) - { - ck_tile::hip_check_error( - hipMemsetAsync(kargs.in_ptr, - 0, - zeroing_size * sizeof(typename Types::EDataType), - s_conf.stream_id_)); - } + ck_tile::hip_check_error( + hipMemsetAsync(kargs.in_ptr, + 0, + zeroing_size * sizeof(typename Types::EDataType), + s_conf.stream_id_)); } }; diff --git a/experimental/grouped_convolution_tile_instances/generate_instances.py b/experimental/grouped_convolution_tile_instances/generate_instances.py index 796e6b9158..b60e92c728 100755 --- a/experimental/grouped_convolution_tile_instances/generate_instances.py +++ b/experimental/grouped_convolution_tile_instances/generate_instances.py @@ -566,14 +566,12 @@ def parse_bwd_data_instances(instances, problem_name): if pipeline_version == "V6": print(f"Skipping instance {instance_id} with V6 since it's not supported yet.") continue - - # Check vector sizes for A and B tensors - we cannot oversubscribe. - num_tile_elements_a = m_per_xdl * k_per_xdl - num_tile_elements_b = n_per_xdl * k_per_xdl - max_vector_size_a = max(1, num_tile_elements_a // block_size) - max_vector_size_b = max(1, num_tile_elements_b // block_size) - a_scalar_per_vector = min(a_scalar_per_vector, max_vector_size_a) - b_scalar_per_vector = min(b_scalar_per_vector, max_vector_size_b) + if k_per_block > (warp_size * a_scalar_per_vector) or n_per_block > (warp_size * b_scalar_per_vector): + print(f"Skipping instance {instance_id} with multiple warps per continous tile dim since it's not supported yet.") + continue + if a_scalar_per_vector > (m_per_block * k_per_block) // block_size or b_scalar_per_vector > (n_per_block * k_per_block) // block_size: + print(f"Skipping instance {instance_id} because current scalar per vector exceedes tile size") + continue conv = ConvInstanceTemplateParams( spec, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp index 6e047dd64a..2f9a9cd21b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp @@ -28,8 +28,9 @@ namespace ck { enum Activation { - gelu_and_mul = 0, - silu_and_mul = 1 + gelu_and_mul = 0, + silu_and_mul = 1, + swiglustep_and_mul = 2 }; template , pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + gate = gate < 7.0f ? gate : 7.0f; + up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f; + c_thread_buf(cidx) = gate * up; + } else if(ActivationOperation == Activation::gelu_and_mul) { float gate = c_thread_buf[cidx]; @@ -2118,6 +2137,25 @@ struct GridwiseMoeGemmBlockScale tensor_operation::element_wise::Silu{}(gate, gate); c_thread_buf(cidx) = gate * up; } + else if constexpr(ActivationOperation == Activation::swiglustep_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weight; + up = up * topk_weight; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + gate = gate < 7.0f ? gate : 7.0f; + up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f; + c_thread_buf(cidx) = gate * up; + } else if(ActivationOperation == Activation::gelu_and_mul) { float gate = c_thread_buf[cidx]; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index c6628f66be..a523acd291 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -759,18 +759,19 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.sink_ptr != nullptr ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s : -numeric::infinity(); - const index_t seqlen_k = [&]() { + // WA i_batch capture structure binding before c++20 + const index_t seqlen_k = [&, i_batch_ = i_batch]() { if constexpr(kKVLookupTable == BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D) { - const int32_t page_start = kargs.page_table.kv_indptr[i_batch]; - const int32_t page_end = kargs.page_table.kv_indptr[i_batch + 1]; + const int32_t page_start = kargs.page_table.kv_indptr[i_batch_]; + const int32_t page_end = kargs.page_table.kv_indptr[i_batch_ + 1]; const int32_t num_page_blocks = page_end - page_start; const int32_t last_page_len = [&]() { if constexpr(kPageBlockSize == 1) return static_cast(kPageBlockSize); else - return kargs.page_table.kv_last_page_lens[i_batch]; + return kargs.page_table.kv_last_page_lens[i_batch_]; }(); return num_page_blocks > 0 ? static_cast((num_page_blocks - 1) * kargs.page_block_size + @@ -780,21 +781,22 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D { if(kargs.page_table.seqlen_k_ptr != nullptr) - return static_cast(kargs.page_table.seqlen_k_ptr[i_batch]); + return static_cast(kargs.page_table.seqlen_k_ptr[i_batch_]); else return kargs.seqlen_k; } }(); - const int32_t* page_idx = [&]() { + // WA i_batch capture structure binding before c++20 + const int32_t* page_idx = [&, i_batch_ = i_batch]() { if constexpr(kKVLookupTable == BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D) { - return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch]; + return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch_]; } else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D { return kargs.page_table.block_table_ptr + - static_cast(i_batch) * + static_cast(i_batch_) * kargs.page_table.batch_stride_block_table; } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index a8b94b6e41..4f2d3d58c2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -291,6 +291,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout; static constexpr auto QScaleEnum = Problem::QScaleEnum; + static constexpr bool kHasSink = Problem::kHasSink; // For KV_BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] // This avoids explicit P *= scale_p and v_descale /= scale_p operations @@ -546,11 +547,25 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync } __builtin_amdgcn_sched_barrier(0); - const auto q_origin = q_dram_window.get_window_origin(); - const auto [seqlen_k_start, seqlen_k_end] = - mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); - - const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + const auto q_origin = q_dram_window.get_window_origin(); + const auto tile_range_result = [&mask, &q_origin]() { + if constexpr(kHasSink) + return mask.GetSinkTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}); + else + { + auto [start, end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + return ck_tile::make_tuple(0, start, end); + } + }(); + const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); + const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); + const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); + const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); + const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0; + const auto num_total_loop = + integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop; // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK) @@ -576,7 +591,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}); + {kv_load_start, 0}); auto k_dist = Policy::template MakeKDramTileDistribution(); auto k_coord = k_dist.calculate_index(); @@ -585,7 +600,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // kPageBlockSize >= kN0: within-page offset only (SRD rebased per page via rebase_k_window) // kPageBlockSize < kN0: global offset, must fit int32 statically_indexed_array k_offsets; - index_t current_seq_k = seqlen_k_start; + index_t current_seq_k = kv_load_start; // Load physical pages first, then compute offsets. // k_physical_pages can be reused for descale lookup later. @@ -668,11 +683,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + {bias_origin.at(number<0>{}), kv_load_start}, // M/N Policy::template MakeBiasDramTileDistribution()); auto randval_dram_window = dropout.template MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_k_start); + randval_dram_block_window_tmp, kv_load_start); auto v_dist = Policy::template MakeVDramTileDistribution(); auto v_coord = v_dist.calculate_index(); @@ -895,7 +910,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto v_dram_window = make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_k_start}, // TODO: hdim split? + {0, kv_load_start}, // TODO: hdim split? v_dist, v_offsets, number<1>{}, // HsGatherDim @@ -1097,6 +1112,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync #endif } } + if constexpr(kHasSink) + { + if(i_total_loops == num_sink_loop - 1) + move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end}); + } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { @@ -1108,19 +1128,36 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if(need_perpixel_check) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = - q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = - k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return !variant.LogitsMask(variant_params, - block_indices.batch_idx, - row, - col, - block_indices.qo_head_idx, - block_indices.kv_head_idx); + auto apply_mask = [&](auto&& mask_func) { + set_tile_if(s_acc, + -numeric::infinity(), + [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !mask_func(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + }; + + if constexpr(kHasSink) + { + apply_mask([&](auto&&... args) { + return variant.LogitsSinkMask( + std::forward(args)...); }); + } + else + { + apply_mask([&](auto&&... args) { + return variant.LogitsMask(std::forward(args)...); + }); + } } } @@ -1297,12 +1334,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { auto randval_ptr = reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + index_t seq_offset = [&]() { + if constexpr(kHasSink) + { + const bool in_sink_phase = (num_sink_loop > i_total_loops); + if(i_total_loops == num_sink_loop) + move_tile_window(randval_dram_window, + {0, seqlen_k_start - sink_seq_end}); + return in_sink_phase + ? (kv_load_start + i_total_loops * kN0) + : (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0); + } + else + return seqlen_k_start + i_total_loops * kN0; + }(); dropout .template Run( - randval_ptr, - seqlen_k_start + i_total_loops * kN0, - p_compute, - randval_dram_window); + randval_ptr, seq_offset, p_compute, randval_dram_window); } #if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN @@ -1396,9 +1444,19 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync i_total_loops++; if(i_total_loops < num_total_loop) { - current_seq_k += kN0; + // For sink: after the last sink tile, jump K/V to seqlen_k_start; + // otherwise advance by one normal tile. + const index_t k_advance = [&]() -> index_t { + if constexpr(kHasSink) + return (i_total_loops == num_sink_loop) + ? (seqlen_k_start - sink_seq_end + kN0) + : kN0; + else + return kN0; + }(); + current_seq_k += k_advance; // move K tile windows - move_tile_window(k_dram_block_window, {kN0, 0}); + move_tile_window(k_dram_block_window, {k_advance, 0}); k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); // KV_BLOCKSCALE: reload physical pages for the new tile @@ -1427,6 +1485,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); k_dram_window.update_page_idx(k_offsets); rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]); + + // After sink→window transition (i_total_loops == num_sink_loop), V window + // was advanced by kN0 (one normal iter), but current_seq_k jumped by k_advance + // = seqlen_k_start - sink_seq_end + kN0 > kN0. Re-init V to current_seq_k. + if constexpr(kHasSink) + { + if(i_total_loops == num_sink_loop && num_sink_loop > 0) + { + prefetch_v_physical_pages(number<0>{}); + update_v_offsets(number<0>{}); + v_dram_window.update_page_idx(v_offsets); + rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); + } + } + if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) __builtin_amdgcn_s_barrier(); diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 0670985e4f..7df39c3d11 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -53,6 +53,7 @@ template + kHasSink_> { static constexpr auto kKVMemoryLayout = kKVMemoryLayout_; static constexpr auto kKVLookupTable = kKVLookupTable_; diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt index 514f8e9668..7a318b4c19 100644 --- a/test/grouped_convnd_bwd_data/CMakeLists.txt +++ b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -22,6 +22,17 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") target_link_libraries(test_grouped_conv_bwd_data_scale PRIVATE gtest_main getopt::getopt utility device_grouped_conv3d_bwd_data_scale_instance) endif() +if(GPU_TARGETS MATCHES "gfx9") + if(CK_EXPERIMENTAL_BUILDER) + add_gtest_executable(test_grouped_convnd_bwd_data_tile test_grouped_convnd_bwd_data_tile.cpp) + target_compile_options(test_grouped_convnd_bwd_data_tile PRIVATE -Wno-global-constructors -Wno-undef -Wno-c++20-compat) + target_link_libraries(test_grouped_convnd_bwd_data_tile PRIVATE gtest_main getopt::getopt utility) + if(TARGET device_grouped_conv_bwd_data_tile_instances) + target_link_libraries(test_grouped_convnd_bwd_data_tile PRIVATE device_grouped_conv_bwd_data_tile_instances) + endif() + endif() +endif() + if (CK_USE_XDL OR CK_USE_WMMA) add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface.cpp) if(result EQUAL 0) diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_tile.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_tile.cpp new file mode 100644 index 0000000000..0b1c6e55f7 --- /dev/null +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_tile.cpp @@ -0,0 +1,258 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck_tile/builder/testing/conv/ck_tile.hpp" +#include "ck_tile/host/device_prop.hpp" +#include "profiler/grouped_convolution_backward_data_tile_algs.hpp" + +static ck::index_t args_mask = 0xffff; +static ck::index_t instance_index = -1; + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace ckp = ck_tile::builder::profiling; + +template +struct SignatureDetails +{ + static constexpr ck_tile::index_t num_spatial_dim = num_spatial_dim_; + static constexpr ckb::DataType data_type = data_type_; + static constexpr ckb::DataType acc_data_type = acc_data_type_; + static constexpr ckb::TensorLayout in_layout = in_layout_; + static constexpr ckb::TensorLayout wei_layout = wei_layout_; + static constexpr ckb::TensorLayout out_layout = out_layout_; +}; + +template +class TestGroupedConvndBwdDataTile : public ::testing::Test +{ + protected: + static constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = SignatureDetailsType::num_spatial_dim, + .direction = ckb::ConvDirection::BACKWARD_DATA, + .data_type = SignatureDetailsType::data_type, + .accumulation_data_type = SignatureDetailsType::acc_data_type, + .input = {.config = {.layout = SignatureDetailsType::in_layout}}, + .weight = {.config = {.layout = SignatureDetailsType::wei_layout}}, + .output = {.config = {.layout = SignatureDetailsType::out_layout}}}; + + std::vector> conv_args; + std::vector split_ks{"1", "2"}; + + template + void Run() + { + ASSERT_FALSE(conv_args.empty()); + bool pass = true; + for(size_t i = 0; i < conv_args.size(); i++) + { + for(auto& split_k : split_ks) + { + if((args_mask & (1 << i)) == 0) + { + continue; + } + auto& args = conv_args[i]; + + auto inputs = alloc_inputs(args); + auto outputs = alloc_outputs(args); + ckt::init_tensor_buffer_uniform_int( + inputs.get().weight, args.make_weight_descriptor(), -5, 5); + ckt::init_tensor_buffer_uniform_int( + inputs.get().output, args.make_output_descriptor(), -5, 5); + + HIP_CHECK_ERROR( + hipMemset(outputs.get().input, + 0, + args.make_input_descriptor().get_element_space_size_in_bytes())); + + std::cout << args.make_input_descriptor() << std::endl; + std::cout << args.make_weight_descriptor() << std::endl; + std::cout << args.make_output_descriptor() << std::endl; + [[maybe_unused]] auto&& [case_passed, + avg_time, + op_name, + best_split_k, + best_instance] = + + ckp::run_grouped_conv_backward_data_tile_algs( + args, + split_k, + -1, + inputs.get(), + outputs.get(), + ck_tile::stream_config{nullptr, false /*time_kernel*/}); + + pass = pass && case_passed; + } + } + EXPECT_TRUE(pass); + } + + void conv_args_append(std::size_t, + std::size_t G, + std::size_t N, + std::size_t K, + std::size_t C, + const std::vector& filter_spatial_lengths, + const std::vector& input_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads) + { + ckt::Args args = { + .lengths = + { + .batch_size = N, + .groups = G, + .input_channels = C, + .output_channels = K, + .image = ckt::filter_extent_from_vector( + input_spatial_lengths), + .filter = ckt::filter_extent_from_vector( + filter_spatial_lengths), + }, + .filter_strides = ckt::filter_extent_from_vector( + conv_filter_strides), + .filter_dilation = + ckt::filter_extent_from_vector( + conv_filter_dilations), + .input_left_pad = ckt::filter_extent_from_vector( + input_left_pads), + .input_right_pad = + ckt::filter_extent_from_vector( + input_right_pads), + .a_elementwise_op = {}, + .b_elementwise_op = {}, + .cde_elementwise_op = {}, + }; + conv_args.push_back(args); + } +}; + +using KernelTypes2d = ::testing::Types, + SignatureDetails<2, + ckb::DataType::FP16, + ckb::DataType::FP32, + ckb::TensorLayout::NHWGC, + ckb::TensorLayout::GKYXC, + ckb::TensorLayout::NHWGK>, + SignatureDetails<2, + ckb::DataType::BF16, + ckb::DataType::FP32, + ckb::TensorLayout::NHWGC, + ckb::TensorLayout::GKYXC, + ckb::TensorLayout::NHWGK>>; + +using KernelTypes3d = ::testing::Types, + SignatureDetails<3, + ckb::DataType::FP16, + ckb::DataType::FP32, + ckb::TensorLayout::NDHWGC, + ckb::TensorLayout::GKZYXC, + ckb::TensorLayout::NDHWGK>, + SignatureDetails<3, + ckb::DataType::BF16, + ckb::DataType::FP32, + ckb::TensorLayout::NDHWGC, + ckb::TensorLayout::GKZYXC, + ckb::TensorLayout::NDHWGK>>; + +template +class TestGroupedConvndBwdDataTile2d : public TestGroupedConvndBwdDataTile +{ +}; + +template +class TestGroupedConvndBwdDataTile3d : public TestGroupedConvndBwdDataTile +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndBwdDataTile2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndBwdDataTile3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndBwdDataTile2d, Test2D) +{ + this->conv_args.clear(); + + // GroupedGemmGroupsNum = 4, ZTilde * YTilde * XTilde = 4, MaxGroupedGemmGroupsNum = 32 + this->conv_args_append(2, 2, 2, 16, 16, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}); + // GroupedGemmGroupsNum = 9, ZTilde * YTilde * XTilde = 36, MaxGroupedGemmGroupsNum = 32 + this->conv_args_append(2, 2, 2, 16, 16, {3, 3}, {28, 28}, {6, 6}, {1, 1}, {1, 1}, {1, 1}); + // GroupedGemmGroupsNum = 36, ZTilde * YTilde * XTilde = 36, MaxGroupedGemmGroupsNum = 32 + this->conv_args_append(2, 2, 2, 16, 16, {6, 6}, {28, 28}, {6, 6}, {1, 1}, {1, 1}, {1, 1}); + // GroupedGemmGroupsNum = 32, ZTilde * YTilde * XTilde = 32, MaxGroupedGemmGroupsNum = 32 + this->conv_args_append(2, 2, 2, 16, 16, {4, 8}, {28, 28}, {4, 8}, {1, 1}, {1, 1}, {1, 1}); + this->conv_args_append(2, 2, 2, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + this->conv_args_append(2, 2, 2, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + this->conv_args_append(2, 2, 2, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}); + this->conv_args_append(2, 2, 2, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}); + this->conv_args_append(2, 2, 2, 32, 32, {2, 2}, {12, 12}, {3, 3}, {1, 1}, {0, 0}, {0, 0}); + this->conv_args_append(2, 2, 2, 32, 32, {2, 2}, {12, 12}, {2, 2}, {2, 2}, {0, 0}, {0, 0}); + this->conv_args_append(2, 1, 6, 448, 896, {1, 1}, {118, 182}, {2, 2}, {1, 1}, {0, 0}, {0, 0}); + this->conv_args_append(2, 1, 1, 1, 32, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + this->conv_args_append(2, 1, 1, 64, 3, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + this->conv_args_append(2, 1, 1, 1, 1, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + this->template Run<2>(); +} + +TYPED_TEST(TestGroupedConvndBwdDataTile3d, Test3D) +{ + this->conv_args.clear(); + this->conv_args_append( + 3, 2, 2, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}); + this->conv_args_append( + 3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + this->conv_args_append( + 3, 2, 2, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}); + this->conv_args_append( + 3, 2, 2, 32, 32, {1, 2, 2}, {1, 12, 12}, {1, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}); + this->conv_args_append( + 3, 2, 2, 32, 32, {1, 2, 2}, {1, 12, 12}, {1, 2, 2}, {1, 2, 2}, {0, 0, 0}, {0, 0, 0}); + this->conv_args_append( + 3, 1, 1, 1, 32, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + this->conv_args_append( + 3, 1, 1, 64, 3, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + this->conv_args_append( + 3, 1, 1, 1, 1, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + this->template Run<3>(); +} + +int main(int argc, char** argv) +{ + testing::InitGoogleTest(&argc, argv); + if(argc == 1) {} + else if(argc == 3) + { + args_mask = strtol(argv[1], nullptr, 0); + instance_index = atoi(argv[2]); + } + else + { + std::cout << "Usage of " << argv[0] << std::endl; + std::cout << "Arg1,2: args_mask instance_index(-1 means all)" << std::endl; + } + return RUN_ALL_TESTS(); +}