mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
wip fix
This commit is contained in:
@@ -48,7 +48,7 @@
|
||||
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY 0
|
||||
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE 1
|
||||
#ifndef CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT
|
||||
#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY
|
||||
#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_LAUNCH_BOUNDS
|
||||
|
||||
@@ -87,11 +87,29 @@ CK_TILE_HOST_DEVICE constexpr T max(T x)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T max(T x, T y)
|
||||
CK_TILE_HOST constexpr T max(T x, T y)
|
||||
{
|
||||
return x > y ? x : y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE constexpr T max(T x, T y)
|
||||
{
|
||||
return x > y ? x : y;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE constexpr float max(float x, float y)
|
||||
{
|
||||
return __builtin_fmaxf(x, y); // can resultin v_max3_f32
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE constexpr double max(double x, double y)
|
||||
{
|
||||
return __builtin_fmax(x, y); // maybe still v_max3_f32
|
||||
}
|
||||
|
||||
template <index_t X>
|
||||
CK_TILE_HOST_DEVICE constexpr index_t max(number<X>, index_t y)
|
||||
{
|
||||
@@ -118,11 +136,29 @@ CK_TILE_HOST_DEVICE constexpr T min(T x)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T min(T x, T y)
|
||||
CK_TILE_HOST constexpr T min(T x, T y)
|
||||
{
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE constexpr T min(T x, T y)
|
||||
{
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE constexpr float min(float x, float y)
|
||||
{
|
||||
return __builtin_fminf(x, y);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE constexpr double min(double x, double y)
|
||||
{
|
||||
return __builtin_fmin(x, y);
|
||||
}
|
||||
|
||||
template <index_t X>
|
||||
CK_TILE_HOST_DEVICE constexpr index_t min(number<X>, index_t y)
|
||||
{
|
||||
|
||||
@@ -60,7 +60,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr index_t NDimY = InTensor::get_tile_distribution().GetNumOfDimensionY();
|
||||
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
|
||||
|
||||
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
|
||||
|
||||
|
||||
@@ -78,9 +78,9 @@ struct tile_distribution
|
||||
Ys2DDescriptor ys_to_d_;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_x() { return NDimX; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionY() { return NDimY; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionP() { return NDimP; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionR() { return NDimR; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_y() { return NDimY; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
|
||||
{
|
||||
|
||||
@@ -36,8 +36,8 @@ struct tile_window_with_static_distribution
|
||||
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
|
||||
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
|
||||
|
||||
static constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
|
||||
static constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();
|
||||
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
|
||||
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
@@ -265,7 +265,7 @@ struct tile_window_with_static_distribution
|
||||
window_adaptor_vector_lengths, window_adaptor_vector_strides);
|
||||
|
||||
// [y0, y1, ...]
|
||||
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::GetNumOfDimensionP(),
|
||||
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
|
||||
NDimWindowAdaptorTop,
|
||||
1>::type{};
|
||||
|
||||
|
||||
@@ -38,7 +38,11 @@ struct Default2DEpilogue
|
||||
// TODO: this is ugly
|
||||
if constexpr(kPadM || kPadN)
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
// o_dram_window_tmp.foo();
|
||||
// ODataType{}.foo();
|
||||
// o_acc_tile.foo();
|
||||
auto x = cast_tile<ODataType>(o_acc_tile);
|
||||
store_tile_raw(o_dram_window_tmp, x);
|
||||
buffer_store_fence();
|
||||
}
|
||||
else
|
||||
|
||||
@@ -17,8 +17,8 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
using DstrEncodeDetail = typename DstrEncode::detail;
|
||||
|
||||
constexpr index_t NDimP = Dstr::GetNumOfDimensionP();
|
||||
constexpr index_t NDimR = Dstr::GetNumOfDimensionR();
|
||||
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
|
||||
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user