mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Remove unnecessary changes
This commit is contained in:
@@ -41,10 +41,6 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
|
||||
{
|
||||
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
|
||||
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
|
||||
// r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
// r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
// r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
// r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
@@ -32,10 +32,6 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
|
||||
{
|
||||
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
|
||||
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
|
||||
// r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
// r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
// r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
// r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
@@ -82,8 +82,7 @@ CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }
|
||||
|
||||
CK_TILE_DEVICE index_t get_warp_id()
|
||||
{
|
||||
// return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
|
||||
return threadIdx.x / get_warp_size();
|
||||
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
|
||||
|
||||
@@ -191,6 +191,15 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// use llvm builtin bf16 data type after ROCm 6.5
|
||||
#ifndef CK_TILE_USE_LLVM_BUILTIN_BF16
|
||||
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 5 && HIP_VERSION_PATCH >= 50421) || \
|
||||
(HIP_VERSION_MAJOR >= 7)
|
||||
#define CK_TILE_USE_LLVM_BUILTIN_BF16 1
|
||||
#else
|
||||
#define CK_TILE_USE_LLVM_BUILTIN_BF16 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_DEBUG_LOG
|
||||
#define CK_TILE_DEBUG_LOG 0
|
||||
#endif
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#if defined(__gfx950__)
|
||||
#if CK_TILE_USE_LLVM_BUILTIN_BF16
|
||||
#include <hip/hip_bfloat16.h>
|
||||
#endif
|
||||
#include <stdint.h>
|
||||
@@ -105,8 +105,7 @@ struct native_t<bfloat16_t>
|
||||
using bf16_t = bfloat16_t;
|
||||
using bf16_raw_t = typename bf16_t::raw_type;
|
||||
#else
|
||||
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 5 && HIP_VERSION_PATCH >= 50421) || \
|
||||
(HIP_VERSION_MAJOR >= 7)
|
||||
#if CK_TILE_USE_LLVM_BUILTIN_BF16
|
||||
using bfloat16_t = __bf16;
|
||||
#else
|
||||
using bfloat16_t = ushort;
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/debug.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -131,7 +130,6 @@ struct DefaultTranspose
|
||||
util::is_sequence_suffix_v<decltype(quad_hs[I0]), decltype(input_hs[I0])>;
|
||||
static constexpr bool suffix_valid_dim1 =
|
||||
util::is_sequence_suffix_v<decltype(quad_hs[I1]), decltype(input_hs[I1])>;
|
||||
// using bbb = decltype(CK_PRINT<decltype(quad_hs[I1]), decltype(input_hs[I1])>());
|
||||
|
||||
// 3. PS→RHS mapping constraints
|
||||
static constexpr auto input_ps_major = InDstrEncode::ps_to_rhss_major_;
|
||||
@@ -171,8 +169,6 @@ struct DefaultTranspose
|
||||
static constexpr bool ys_mapping_valid =
|
||||
(input_ys_major.back() == 2) && (input_ys_minor.back() == input_hs[I1].size() - 1);
|
||||
|
||||
// using aaa = decltype(CK_PRINT<dims_valid, suffix_valid_dim0, suffix_valid_dim1,
|
||||
// ps_mapping_valid, ys_mapping_valid>());
|
||||
static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 &&
|
||||
ps_mapping_valid && ys_mapping_valid;
|
||||
};
|
||||
|
||||
@@ -288,8 +288,7 @@ struct tile_window_with_static_distribution
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value =
|
||||
__builtin_amdgcn_readfirstlane(size_per_buf + size_per_wave * get_warp_id());
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
m0_set_with_memory(m0_init_value); // This should be wave independent
|
||||
|
||||
using Traits = typename Base::Traits;
|
||||
@@ -434,8 +433,6 @@ struct tile_window_with_static_distribution
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
|
||||
// printf("Tid: %03d, tr_load_idx: %d\n",
|
||||
// get_thread_local_1d_id(),bottom_tensor_thread_coord.get_offset());
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
this->get_bottom_tensor_view()
|
||||
|
||||
@@ -516,8 +516,7 @@ struct tile_window_linear
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value =
|
||||
__builtin_amdgcn_readfirstlane(size_per_buf + size_per_wave * get_warp_id());
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
m0_set_with_memory(m0_init_value); // This should be wave independent
|
||||
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
|
||||
@@ -1071,20 +1071,20 @@ struct FmhaFwdSplitKVKernel
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window_lengths,
|
||||
// k_page_block_navigator,
|
||||
k_page_block_navigator,
|
||||
v_dram_window_lengths,
|
||||
// v_page_block_navigator,
|
||||
v_page_block_navigator,
|
||||
bias_dram_window,
|
||||
lse_acc_dram_window,
|
||||
// kargs.num_splits,
|
||||
// i_split_,
|
||||
kargs.num_splits,
|
||||
i_split_,
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
// kv_l2p_offset,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -45,7 +45,7 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
// template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; };
|
||||
|
||||
// fp16 2:4 structural sparsity
|
||||
@@ -76,7 +76,7 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
// template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; };
|
||||
|
||||
// fp8
|
||||
|
||||
Reference in New Issue
Block a user