mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Update include path to break the remod's cyclic dep issue (#2978)
* Update include path to break the cyclic dep issue
* Use ck_tile::permute_vectors_i4x4_b in tile engine
---------
Co-authored-by: Damien Lejeune <damien.lejeune@amd.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
[ROCm/composable_kernel commit: 46c10c316d]
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/joinable_thread.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
#include "ck_tile/host/ranges.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_dropout_randval.hpp"
|
||||
|
||||
@@ -9,5 +9,7 @@
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -12,5 +12,7 @@
|
||||
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
|
||||
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -10,5 +10,7 @@
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -8,5 +8,7 @@
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -14,5 +14,7 @@
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -60,5 +60,7 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -16,5 +16,7 @@
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -30,18 +30,18 @@
|
||||
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
|
||||
@@ -72,5 +72,7 @@
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp"
|
||||
@@ -15,11 +15,13 @@
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -12,5 +12,7 @@
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -7,5 +7,7 @@
|
||||
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
|
||||
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -10,5 +10,7 @@
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -7,5 +7,7 @@
|
||||
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
|
||||
#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -6,5 +6,7 @@
|
||||
#include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"
|
||||
#include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -11,5 +11,7 @@
|
||||
#include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp"
|
||||
#include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -11,5 +11,7 @@
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -10,5 +10,7 @@
|
||||
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -6,5 +6,7 @@
|
||||
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
|
||||
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -6,5 +6,7 @@
|
||||
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
|
||||
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -8,5 +8,7 @@
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -74,58 +74,6 @@ constexpr auto is_row_major(Layout)
|
||||
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// Permutation function for pk_int4_t
|
||||
template <typename Tensor>
|
||||
void permute_vectors_i4x4_b(Tensor& tensor)
|
||||
{
|
||||
const ck_tile::index_t K = tensor.get_length(0);
|
||||
const ck_tile::index_t N = tensor.get_length(1);
|
||||
// vector pk_i4x4 permute
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j += 8)
|
||||
{
|
||||
int8_t input[8];
|
||||
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
int8_t i4x2 = tensor(j + k * 2, i).data;
|
||||
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
|
||||
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
|
||||
// permute 01234567->20643175
|
||||
{
|
||||
int8_t hi = input[2];
|
||||
int8_t lo = input[0];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
tensor(j + 0, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[6];
|
||||
int8_t lo = input[4];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
tensor(j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[3];
|
||||
int8_t lo = input[1];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
tensor(j + 4, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[7];
|
||||
int8_t lo = input[5];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
tensor(j + 6, i) = i4x2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Structure to hold kernel traits for dispatcher
|
||||
struct KernelTraits
|
||||
{
|
||||
|
||||
@@ -96,7 +96,7 @@ class GemmProfiler
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
// permute_tensor_b<decltype(b_k_n_dev)>(b_k_n_dev);
|
||||
permute_vectors_i4x4_b(b_k_n_dev);
|
||||
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user