mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
fix build wip
This commit is contained in:
@@ -26,6 +26,8 @@ set(version 1.1.0)
|
||||
project(composable_kernel VERSION ${version} LANGUAGES CXX)
|
||||
include(CTest)
|
||||
|
||||
find_package(Python3 3.7 COMPONENTS Interpreter REQUIRED)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||
|
||||
if (DTYPES)
|
||||
|
||||
@@ -14,7 +14,7 @@ add_custom_command(
|
||||
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
)
|
||||
|
||||
set(EXAMPLE_FMHA_FWD "example_fmha_fwd")
|
||||
set(EXAMPLE_FMHA_FWD "ck_tile_example_fmha_fwd")
|
||||
add_example_executable(${EXAMPLE_FMHA_FWD} fmha_fwd.cpp)
|
||||
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
import argparse
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, tuple
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass
|
||||
import copy
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/fmha.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
|
||||
enum class mask_enum
|
||||
{
|
||||
|
||||
5
example/ck_tile/CMakeLists.txt
Normal file
5
example/ck_tile/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
# Expose this directory as an include root to everything configured from here
# on (including the subdirectories added below). AFTER appends the path to the
# end of the include search list instead of prepending it.
include_directories(AFTER ${CMAKE_CURRENT_LIST_DIR})

# Descend into the 01_fmha example directory.
add_subdirectory(01_fmha)
|
||||
@@ -224,7 +224,7 @@ struct pad : public base_transform<1, 1>
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
|
||||
{
|
||||
return ck_tile::ck_tile::is_known_at_compile_time<UpLengths>::value &&
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
|
||||
ck_tile::is_known_at_compile_time<LeftPadLength>::value &&
|
||||
ck_tile::is_known_at_compile_time<RightPadLength>::value;
|
||||
}
|
||||
@@ -577,7 +577,7 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
|
||||
using UpperIndex = multi_index<1>;
|
||||
|
||||
using UpLengths =
|
||||
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, number<1>{})));
|
||||
decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));
|
||||
|
||||
using LowLengthsMagicDivisor = decltype(generate_tuple(
|
||||
lambda_merge_generate_MagicDivision_calculate_magic_divisor<LowLengths>{},
|
||||
@@ -597,7 +597,7 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
|
||||
low_lengths_magic_divisor_{generate_tuple(
|
||||
[&](auto i) { return magic_division::calculate_magic_numbers(low_lengths[i]); },
|
||||
number<NDimLow>{})},
|
||||
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, I1))}
|
||||
up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, I1))}
|
||||
{
|
||||
static_assert(LowerIndex::size() == NDimLow, "wrong!");
|
||||
}
|
||||
@@ -722,10 +722,10 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
|
||||
using UpperIndex = multi_index<1>;
|
||||
|
||||
using LowLengthsScan =
|
||||
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, number<1>{}));
|
||||
decltype(container_reverse_exclusive_scan(LowLengths{}, multiplies{}, number<1>{}));
|
||||
|
||||
using UpLengths =
|
||||
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, number<1>{})));
|
||||
decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));
|
||||
|
||||
LowLengths low_lengths_;
|
||||
LowLengthsScan low_lengths_scan_;
|
||||
@@ -736,8 +736,8 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
|
||||
CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod(const LowLengths& low_lengths)
|
||||
: low_lengths_{low_lengths},
|
||||
low_lengths_scan_{
|
||||
container_reverse_exclusive_scan(low_lengths, math::multiplies{}, number<1>{})},
|
||||
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, number<1>{}))}
|
||||
container_reverse_exclusive_scan(low_lengths, multiplies{}, number<1>{})},
|
||||
up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, number<1>{}))}
|
||||
{
|
||||
static_assert(LowerIndex::size() == NDimLow, "wrong!");
|
||||
}
|
||||
@@ -855,7 +855,7 @@ struct unmerge : public base_transform<1, UpLengths::size()>
|
||||
using UpperIndex = multi_index<NDimUp>;
|
||||
|
||||
using UpLengthsScan =
|
||||
decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies{}, number<1>{}));
|
||||
decltype(container_reverse_exclusive_scan(UpLengths{}, multiplies{}, number<1>{}));
|
||||
|
||||
UpLengths up_lengths_;
|
||||
UpLengthsScan up_lengths_scan_;
|
||||
@@ -864,8 +864,7 @@ struct unmerge : public base_transform<1, UpLengths::size()>
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr unmerge(const UpLengths& up_lengths)
|
||||
: up_lengths_{up_lengths},
|
||||
up_lengths_scan_{
|
||||
container_reverse_exclusive_scan(up_lengths, math::multiplies{}, number<1>{})}
|
||||
up_lengths_scan_{container_reverse_exclusive_scan(up_lengths, multiplies{}, number<1>{})}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -944,7 +943,7 @@ struct unmerge : public base_transform<1, UpLengths::size()>
|
||||
{
|
||||
if(low_vector_lengths[0] != -1)
|
||||
{
|
||||
up_vector_lengths(NDimUp - 1) = math::gcd(low_vector_lengths[0], up_length_last);
|
||||
up_vector_lengths(NDimUp - 1) = gcd(low_vector_lengths[0], up_length_last);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -979,7 +978,7 @@ struct freeze : public base_transform<1, 0>
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_upper_lengths() { return Tuple<>{}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_upper_lengths() { return tuple<>{}; }
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
|
||||
@@ -1428,7 +1427,7 @@ struct xor_t : public base_transform<2, 2>
|
||||
{
|
||||
if(low_vector_lengths[1] != -1)
|
||||
{
|
||||
up_vector_lengths(1) = math::gcd(low_vector_lengths[1], math::abs(right_shift_));
|
||||
up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1546,35 +1545,35 @@ struct offset : public base_transform<1, 1>
|
||||
template <typename LowLength>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength& low_length)
|
||||
{
|
||||
return PassThrough<LowLength>{low_length};
|
||||
return pass_through<LowLength>{low_length};
|
||||
}
|
||||
|
||||
template <typename LowLength, typename left_pad, typename right_pad, bool SkipIsValidCheck = false>
|
||||
template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_pad_transform(const LowLength& low_length,
|
||||
const left_pad& left_pad,
|
||||
const right_pad& right_pad,
|
||||
const LeftPad& left_pad,
|
||||
const RightPad& right_pad,
|
||||
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
|
||||
{
|
||||
return pad<LowLength, left_pad, right_pad, SkipIsValidCheck>{low_length, left_pad, right_pad};
|
||||
return pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
|
||||
}
|
||||
|
||||
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_left_pad_transform(
|
||||
const LowLength& low_length,
|
||||
const LeftPadLength& left_pad,
|
||||
const LeftPadLength& left_pad_,
|
||||
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
|
||||
{
|
||||
return left_pad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
|
||||
return left_pad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad_};
|
||||
}
|
||||
|
||||
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_right_pad_transform(
|
||||
const LowLength& low_length,
|
||||
const RightPadLength& right_pad,
|
||||
const RightPadLength& right_pad_,
|
||||
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
|
||||
{
|
||||
return right_pad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad};
|
||||
return right_pad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad_};
|
||||
}
|
||||
|
||||
template <typename UpLengths,
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
@@ -119,7 +120,7 @@ struct space_filling_curve
|
||||
constexpr auto ordered_access_idx = generate_tuple(compute_index, number<nDim>{});
|
||||
#endif
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
statically_indexed_array<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
|
||||
@@ -1754,7 +1754,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
|
||||
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t src_addr_shift = [&]() {
|
||||
if constexpr(oob_conditional_check)
|
||||
return src_thread_element_valid ? 0 : 0x80000000;
|
||||
@@ -1876,7 +1876,7 @@ amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_d
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = [&]() {
|
||||
if constexpr(oob_conditional_check)
|
||||
return dst_thread_element_valid ? 0 : 0x80000000;
|
||||
@@ -1951,7 +1951,7 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
|
||||
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
|
||||
|
||||
amd_buffer_atomic_add_impl<scalar_t, vector_size>(
|
||||
@@ -1986,7 +1986,7 @@ amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thr
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
|
||||
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
|
||||
|
||||
amd_buffer_atomic_max_impl<scalar_t, vector_size>(
|
||||
@@ -2028,7 +2028,7 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
|
||||
const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
|
||||
|
||||
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
|
||||
#if CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
|
||||
T* lds_ptr = lds_base_ptr + lds_offset;
|
||||
auto const lds_ptr_sgpr =
|
||||
__builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
|
||||
|
||||
@@ -3,6 +3,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
#endif
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#define CK_TILE_HOST __host__
|
||||
#define CK_TILE_DEVICE __device__
|
||||
@@ -54,3 +59,64 @@
|
||||
|
||||
#define CK_TILE_MAX_THREAD_PER_BLOCK 256
|
||||
#define CK_TILE_MIN_BLOCK_PER_CU 2
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
|
||||
#define CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_AMD_BUFFER_LOAD
|
||||
#define CK_TILE_USE_AMD_BUFFER_LOAD 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_AMD_BUFFER_STORE
|
||||
#define CK_TILE_USE_AMD_BUFFER_STORE 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
|
||||
#endif
|
||||
|
||||
// buffer atomic add: floating point
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__) // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#else // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
|
||||
#endif
|
||||
|
||||
#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)) // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
|
||||
#else
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
#define CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
|
||||
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_DEBUG_LOG
|
||||
#define CK_TILE_DEBUG_LOG 0
|
||||
#endif
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -75,11 +75,11 @@ struct map
|
||||
|
||||
CK_TILE_HOST_DEVICE void clear() { size_ = 0; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr index_t find_position(const key& key) const
|
||||
CK_TILE_HOST_DEVICE constexpr index_t find_position(const key& k) const
|
||||
{
|
||||
for(index_t i = 0; i < size(); i++)
|
||||
{
|
||||
if(impl_[i].template at<0>() == key)
|
||||
if(impl_[i].template at<0>() == k)
|
||||
{
|
||||
return i;
|
||||
}
|
||||
@@ -88,39 +88,39 @@ struct map
|
||||
return size_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const_iterator find(const key& key) const
|
||||
CK_TILE_HOST_DEVICE constexpr const_iterator find(const key& k) const
|
||||
{
|
||||
return const_iterator{impl_, find_position(key)};
|
||||
return const_iterator{impl_, find_position(k)};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr iterator find(const key& key)
|
||||
CK_TILE_HOST_DEVICE constexpr iterator find(const key& k)
|
||||
{
|
||||
return iterator{impl_, find_position(key)};
|
||||
return iterator{impl_, find_position(k)};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const data& operator[](const key& key) const
|
||||
CK_TILE_HOST_DEVICE constexpr const data& operator[](const key& k) const
|
||||
{
|
||||
const auto it = find(key);
|
||||
const auto it = find(k);
|
||||
|
||||
// FIXME
|
||||
assert(it.pos_ < size());
|
||||
// assert(it.pos_ < size());
|
||||
|
||||
return impl_[it.pos_].template at<1>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr data& operator()(const key& key)
|
||||
CK_TILE_HOST_DEVICE constexpr data& operator()(const key& k)
|
||||
{
|
||||
auto it = find(key);
|
||||
auto it = find(k);
|
||||
|
||||
// if entry not found
|
||||
if(it.pos_ == size())
|
||||
{
|
||||
impl_(it.pos_).template at<0>() = key;
|
||||
impl_(it.pos_).template at<0>() = k;
|
||||
size_++;
|
||||
}
|
||||
|
||||
// FIXME
|
||||
assert(size_ <= max_size);
|
||||
// assert(size_ <= max_size);
|
||||
|
||||
return impl_(it.pos_).template at<1>();
|
||||
}
|
||||
@@ -146,12 +146,12 @@ struct map
|
||||
//
|
||||
printf("impl_: [");
|
||||
//
|
||||
for(const auto& [key, data] : *this)
|
||||
for(const auto& [k, d] : *this)
|
||||
{
|
||||
printf("{key: ");
|
||||
print(key);
|
||||
print(k);
|
||||
printf(", data: ");
|
||||
print(data);
|
||||
print(d);
|
||||
printf("}, ");
|
||||
}
|
||||
//
|
||||
|
||||
@@ -4,11 +4,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -55,14 +56,6 @@ struct sequence
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t size() { return sizeof...(Is); }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; };
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t at(index_t I)
|
||||
{
|
||||
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0
|
||||
static_assert(I < size(), "wrong! I too large");
|
||||
const index_t mData[mSize + 1] = {Is..., 0};
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get()
|
||||
{
|
||||
@@ -81,7 +74,7 @@ struct sequence
|
||||
{
|
||||
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0
|
||||
static_assert(I < size(), "wrong! I too large");
|
||||
const index_t mData[mSize + 1] = {Is..., 0};
|
||||
const index_t mData[size() + 1] = {Is..., 0};
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
@@ -298,7 +291,7 @@ struct arithmetic_sequence_gen
|
||||
static constexpr bool kHasContent =
|
||||
(Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd);
|
||||
|
||||
using type = typename conditional<kHasContent, type0, type1>::type;
|
||||
using type = typename std::conditional<kHasContent, type0, type1>::type;
|
||||
};
|
||||
|
||||
template <index_t IEnd>
|
||||
@@ -403,8 +396,8 @@ struct sequence_reverse<sequence<Ns...>>
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t... Ns>
|
||||
using sequence_reverse_t = typename sequence_reverse<Ns...>::type;
|
||||
// template <index_t... Ns>
|
||||
// using sequence_reverse_t = typename sequence_reverse<Ns...>::type;
|
||||
|
||||
#if 1
|
||||
template <typename Reduce, typename Seq, typename... Seqs>
|
||||
@@ -449,16 +442,15 @@ struct sequence_sort_impl
|
||||
using new_merged_values = decltype(MergedValues::push_back(number<chosen_value>{}));
|
||||
using new_merged_ids = decltype(MergedIds::push_back(number<chosen_id>{}));
|
||||
|
||||
using new_left_values =
|
||||
typename conditional<choose_left, decltype(LeftValues::pop_front()), LeftValues>::type;
|
||||
using new_left_values = typename std::
|
||||
conditional<choose_left, decltype(LeftValues::pop_front()), LeftValues>::type;
|
||||
using new_left_ids =
|
||||
typename conditional<choose_left, decltype(LeftIds::pop_front()), LeftIds>::type;
|
||||
typename std::conditional<choose_left, decltype(LeftIds::pop_front()), LeftIds>::type;
|
||||
|
||||
using new_right_values = typename conditional<choose_left,
|
||||
RightValues,
|
||||
decltype(RightValues::pop_front())>::type;
|
||||
using new_right_values = typename std::
|
||||
conditional<choose_left, RightValues, decltype(RightValues::pop_front())>::type;
|
||||
using new_right_ids =
|
||||
typename conditional<choose_left, RightIds, decltype(RightIds::pop_front())>::type;
|
||||
typename std::conditional<choose_left, RightIds, decltype(RightIds::pop_front())>::type;
|
||||
|
||||
using merge = sorted_sequence_merge_impl<new_left_values,
|
||||
new_left_ids,
|
||||
@@ -557,9 +549,10 @@ struct sequence_sort_impl<sequence<ValueX, ValueY>, sequence<IdX, IdY>, Compare>
|
||||
{
|
||||
static constexpr bool choose_x = Compare{}(ValueX, ValueY);
|
||||
|
||||
using sorted_values =
|
||||
typename conditional<choose_x, sequence<ValueX, ValueY>, sequence<ValueY, ValueX>>::type;
|
||||
using sorted_ids = typename conditional<choose_x, sequence<IdX, IdY>, sequence<IdY, IdX>>::type;
|
||||
using sorted_values = typename std::
|
||||
conditional<choose_x, sequence<ValueX, ValueY>, sequence<ValueY, ValueX>>::type;
|
||||
using sorted_ids =
|
||||
typename std::conditional<choose_x, sequence<IdX, IdY>, sequence<IdY, IdX>>::type;
|
||||
};
|
||||
|
||||
template <index_t Value, index_t Id, typename Compare>
|
||||
@@ -606,14 +599,15 @@ struct sequence_unique_sort
|
||||
using new_remain_ids = decltype(RemainIds::pop_front());
|
||||
|
||||
using new_uniquified_values =
|
||||
typename conditional<is_unique_value,
|
||||
decltype(UniquifiedValues::push_back(number<current_value>{})),
|
||||
UniquifiedValues>::type;
|
||||
typename std::conditional<is_unique_value,
|
||||
decltype(UniquifiedValues::push_back(
|
||||
number<current_value>{})),
|
||||
UniquifiedValues>::type;
|
||||
|
||||
using new_uniquified_ids =
|
||||
typename conditional<is_unique_value,
|
||||
decltype(UniquifiedIds::push_back(number<current_id>{})),
|
||||
UniquifiedIds>::type;
|
||||
typename std::conditional<is_unique_value,
|
||||
decltype(UniquifiedIds::push_back(number<current_id>{})),
|
||||
UniquifiedIds>::type;
|
||||
|
||||
using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
|
||||
new_remain_ids,
|
||||
@@ -662,8 +656,9 @@ struct sequence_unique_sort
|
||||
};
|
||||
|
||||
template <typename SeqMap>
|
||||
struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMap::size(), 1>::type,
|
||||
typename sequence_sort<SeqMap, math::less<index_t>>::type>
|
||||
struct is_valid_sequence_map
|
||||
: std::is_same<typename arithmetic_sequence_gen<0, SeqMap::size(), 1>::type,
|
||||
typename sequence_sort<SeqMap, less<index_t>>::type>
|
||||
{
|
||||
};
|
||||
|
||||
@@ -906,7 +901,7 @@ constexpr auto prefix_sum_sequence(Seq)
|
||||
{
|
||||
return typename sequence_exclusive_scan<sequence<0>,
|
||||
typename sequence_merge<Seq, sequence<0>>::type,
|
||||
math::plus<index_t>>::type{};
|
||||
plus<index_t>>::type{};
|
||||
}
|
||||
|
||||
template <typename Seq, index_t... Is>
|
||||
@@ -920,9 +915,9 @@ namespace detail {
|
||||
template <typename WorkSeq, typename RemainSeq, typename RemainMask>
|
||||
struct pick_sequence_elements_by_mask_impl
|
||||
{
|
||||
using new_work_seq = typename conditional<RemainMask::front(),
|
||||
decltype(WorkSeq::push_back(RemainSeq::front())),
|
||||
WorkSeq>::type;
|
||||
using new_work_seq = typename std::conditional<RemainMask::front(),
|
||||
decltype(WorkSeq::push_back(RemainSeq::front())),
|
||||
WorkSeq>::type;
|
||||
|
||||
using type =
|
||||
typename pick_sequence_elements_by_mask_impl<new_work_seq,
|
||||
@@ -1088,9 +1083,12 @@ struct sorted_sequence_histogram<h_idx, sequence<x>, sequence<r, rs...>>
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename, index_t>
|
||||
struct array; // declare for later use (array->seq utility)
|
||||
|
||||
// SeqSortedSamples: <0, 2, 3, 5, 7>, SeqRange: <0, 3, 6, 9> -> SeqHistogram : <2, 2, 1>
|
||||
template <typename SeqSortedSamples, index_t r, index_t... rs>
|
||||
constexpr auto histogram_sorted_sequence(SeqSortedSamples, sequence<r, rs...>)
|
||||
CK_TILE_HOST_DEVICE constexpr auto histogram_sorted_sequence(SeqSortedSamples, sequence<r, rs...>)
|
||||
{
|
||||
constexpr auto bins = sizeof...(rs); // or categories
|
||||
constexpr auto histogram = [&]() {
|
||||
|
||||
@@ -100,8 +100,8 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
{
|
||||
bool flag = true;
|
||||
|
||||
static_for<0, sizeof...(Xs), 1>{}([&flag](auto i) {
|
||||
flag &= is_static_v<remove_cvref_t<__type_pack_element<i.value, Xs...>>>;
|
||||
static_for<0, sizeof...(T), 1>{}([&flag](auto i) {
|
||||
flag &= is_static_v<remove_cvref_t<__type_pack_element<i.value, T...>>>;
|
||||
});
|
||||
|
||||
return flag;
|
||||
@@ -262,11 +262,11 @@ CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& element)
|
||||
}
|
||||
|
||||
template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<Ts...>& tuple)
|
||||
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<Ts...>& t)
|
||||
{
|
||||
if constexpr(Depth == MaxDepth)
|
||||
{
|
||||
return tuple;
|
||||
return t;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -274,33 +274,33 @@ CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<Ts...>& tuple
|
||||
[&](auto&&... ts) {
|
||||
return concat_tuple(unroll_nested_tuple<Depth + 1, MaxDepth>(ts)...);
|
||||
},
|
||||
tuple);
|
||||
t);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& tuple)
|
||||
CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using Idx = number<tuple<Ts...>::size()() - i - 1>;
|
||||
return tuple.at(Idx{});
|
||||
using Idx = number<tuple<Ts...>::size() - i - 1>;
|
||||
return t.at(Idx{});
|
||||
},
|
||||
number<tuple<Ts...>::size()()>{});
|
||||
}
|
||||
|
||||
// Reduce tuple values in specific range using Function
|
||||
template <index_t Idx, index_t End, typename F, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto tuple_reduce(F&& f, const tuple<Ts...>& tuple)
|
||||
CK_TILE_HOST_DEVICE constexpr auto tuple_reduce(F&& f, const tuple<Ts...>& t)
|
||||
{
|
||||
static_assert(Idx < End, "Wrong parameters for tuple_reduce");
|
||||
if constexpr(Idx + 1 == End)
|
||||
{
|
||||
return tuple.at(number<Idx>{});
|
||||
return t.at(number<Idx>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return f(tuple.at(number<Idx>{}), tuple_reduce<Idx + 1, End>(f, tuple));
|
||||
return f(t.at(number<Idx>{}), tuple_reduce<Idx + 1, End>(f, t));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -322,7 +322,7 @@ CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const T&)
|
||||
template <index_t depth = 0, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const tuple<Ts...>&)
|
||||
{
|
||||
return math::max(tuple_depth<depth + 1>(Ts{})...);
|
||||
return max(tuple_depth<depth + 1>(Ts{})...);
|
||||
}
|
||||
|
||||
template <typename... Seqs>
|
||||
@@ -456,6 +456,7 @@ CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#include <tuple>
|
||||
// WARNING: needed by compiler for C++ structured binding support only, don't use this
|
||||
namespace std {
|
||||
|
||||
@@ -465,7 +466,7 @@ struct tuple_size<ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, s
|
||||
};
|
||||
|
||||
template <std::size_t I, typename... Ts>
|
||||
struct tuple_element<I, ck_tile::tuple<Ts...>> : ck_tile::tuple_element<I, ck_tile::tuple<Ts...>>
|
||||
struct tuple_element<I, ck_tile::tuple<Ts...>> : std::tuple_element<I, std::tuple<Ts...>>
|
||||
{
|
||||
};
|
||||
|
||||
@@ -476,7 +477,7 @@ struct tuple_size<const ck_tile::tuple<Ts...>> : std::integral_constant<std::siz
|
||||
|
||||
template <std::size_t I, typename... Ts>
|
||||
struct tuple_element<I, const ck_tile::tuple<Ts...>>
|
||||
: ck_tile::tuple_element<I, const ck_tile::tuple<Ts...>>
|
||||
: std::tuple_element<I, const std::tuple<Ts...>>
|
||||
{
|
||||
};
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core/numeric/arithmetic.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/limits.hpp"
|
||||
#include <stdint.h>
|
||||
|
||||
#pragma once
|
||||
@@ -32,7 +33,7 @@ struct alignas(2) bfloat16_t
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
static bfloat16_t bit_cast(raw_type x)
|
||||
static constexpr bfloat16_t bit_cast(raw_type x)
|
||||
{
|
||||
bfloat16_t y;
|
||||
y.data = x;
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/core/numeric/arithmetic.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/limits.hpp"
|
||||
#include <stdint.h>
|
||||
#include <type_traits>
|
||||
|
||||
@@ -62,7 +63,7 @@ struct alignas(1) float8_e4m3_t
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
static float8_e4m3_t bit_cast(raw_type x)
|
||||
static constexpr float8_e4m3_t bit_cast(raw_type x)
|
||||
{
|
||||
float8_e4m3_t y;
|
||||
y.data = x;
|
||||
@@ -70,37 +71,40 @@ struct alignas(1) float8_e4m3_t
|
||||
}
|
||||
|
||||
// constructor
|
||||
float8_e4m3_t() = default;
|
||||
constexpr float8_e4m3_t() : data() {}
|
||||
|
||||
// construct from float
|
||||
// Converts a float to its raw fp8 encoding via float_to_fp8_raw().
// `data` is set in the member-initialiser list rather than assigned in the
// body: before C++20 a constexpr constructor must initialise every non-static
// data member, and `data` has no default member initialiser.
CK_TILE_HOST_DEVICE
explicit constexpr float8_e4m3_t(const float& x) : data(float_to_fp8_raw(x)) {}
|
||||
|
||||
// construct from int
|
||||
// Converts an int by first casting it to float, then encoding the float with
// float_to_fp8_raw(). `data` is initialised in the member-initialiser list
// (not assigned in the body) so the constructor satisfies the pre-C++20
// constexpr-constructor rule that every data member be initialised.
CK_TILE_HOST_DEVICE
explicit constexpr float8_e4m3_t(const int& x)
    : data(float_to_fp8_raw(static_cast<float>(x)))
{
}
|
||||
|
||||
// construct from unsigned int
|
||||
// Converts an unsigned int by first casting it to float, then encoding the
// float with float_to_fp8_raw(). `data` is initialised in the
// member-initialiser list (not assigned in the body) so the constructor
// satisfies the pre-C++20 constexpr-constructor rule that every data member be
// initialised.
CK_TILE_HOST_DEVICE
explicit constexpr float8_e4m3_t(const unsigned int& x)
    : data(float_to_fp8_raw(static_cast<float>(x)))
{
}
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator float() const { return fp8_to_float_raw(data); }
|
||||
explicit constexpr operator float() const { return fp8_to_float_raw(data); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator int() const { return static_cast<int>(fp8_to_float_raw(data)); }
|
||||
explicit constexpr operator int() const { return static_cast<int>(fp8_to_float_raw(data)); }
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type& get() { return data; }
|
||||
constexpr raw_type& get() { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type get() const { return data; }
|
||||
constexpr raw_type get() const { return data; }
|
||||
};
|
||||
|
||||
struct alignas(1) float8_e5m2_t
|
||||
@@ -116,7 +120,7 @@ struct alignas(1) float8_e5m2_t
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
static float8_e5m2_t bit_cast(raw_type x)
|
||||
static constexpr float8_e5m2_t bit_cast(raw_type x)
|
||||
{
|
||||
float8_e5m2_t y;
|
||||
y.data = x;
|
||||
@@ -124,37 +128,40 @@ struct alignas(1) float8_e5m2_t
|
||||
}
|
||||
|
||||
// constructor
|
||||
float8_e5m2_t() = default;
|
||||
constexpr float8_e5m2_t() : data() {}
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit float8_e5m2_t(const float& x) { data = float_to_bf8_raw(x); }
|
||||
explicit constexpr float8_e5m2_t(const float& x) { data = float_to_bf8_raw(x); }
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit float8_e5m2_t(const int& x) { data = float_to_bf8_raw(static_cast<float>(x)); }
|
||||
explicit constexpr float8_e5m2_t(const int& x)
|
||||
{
|
||||
data = float_to_bf8_raw(static_cast<float>(x));
|
||||
}
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit float8_e5m2_t(const unsigned int& x)
|
||||
explicit constexpr float8_e5m2_t(const unsigned int& x)
|
||||
{
|
||||
data = float_to_bf8_raw(static_cast<float>(x));
|
||||
}
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator float() const { return bf8_to_float_raw(data); }
|
||||
// explicit conversion: returns bf8_to_float_raw(data)
explicit constexpr operator float() const { return bf8_to_float_raw(data); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator int() const { return static_cast<int>(bf8_to_float_raw(data)); }
|
||||
explicit constexpr operator int() const { return static_cast<int>(bf8_to_float_raw(data)); }
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type& get() { return data; }
|
||||
constexpr raw_type& get() { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type get() const { return data; }
|
||||
constexpr raw_type get() const { return data; }
|
||||
};
|
||||
|
||||
// below is sw fp8 conversion, not utilizing hw instruction
|
||||
|
||||
@@ -3,26 +3,29 @@
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/limits.hpp"
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float_hip(const _Float16& x);
|
||||
using fp16_hip_t = __half; // most of hip internal function use this type
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
_Float16 float_to_fp16_hip(const float& x);
|
||||
float fp16_to_float_hip(const fp16_hip_t& x);
|
||||
|
||||
// HIP use _Float16 as interchangable data type for float16
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t float_to_fp16_hip(const float& x);
|
||||
|
||||
// HIP uses fp16_hip_t as the interchangeable data type for float16
|
||||
struct alignas(2) half_t
|
||||
{
|
||||
using raw_type = uint16_t;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
static half_t bit_cast(raw_type x)
|
||||
static constexpr half_t bit_cast(raw_type x)
|
||||
{
|
||||
half_t y;
|
||||
y.data = x;
|
||||
@@ -30,56 +33,62 @@ struct alignas(2) half_t
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
_Float16 to_fp16() const { return reinterpret_cast<const raw_type&>(data); }
|
||||
constexpr fp16_hip_t to_fp16() const { return ck_tile::bit_cast<fp16_hip_t>(data); }
|
||||
|
||||
// constructor
|
||||
half_t() = default;
|
||||
constexpr half_t() : data() {}
|
||||
|
||||
// construct from HIP half
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit half_t(const _Float16& x) : data(reinterpret_cast<const raw_type&>(x)) {}
|
||||
explicit constexpr half_t(const fp16_hip_t& x) : data(ck_tile::bit_cast<raw_type>(x)) {}
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}
|
||||
explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit half_t(const int& x) : half_t(__int2half_rn(x)) {}
|
||||
explicit constexpr half_t(const int& x) : half_t(static_cast<fp16_hip_t>(__int2half_rn(x))) {}
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit half_t(const unsigned int& x) : half_t(__uint2half_rn(x)) {}
|
||||
explicit constexpr half_t(const unsigned int& x)
|
||||
: half_t(static_cast<fp16_hip_t>(__uint2half_rn(x)))
|
||||
{
|
||||
}
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator float() const { return fp16_to_float_hip(to_fp16()); }
|
||||
explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator int() const { return static_cast<int>(fp16_to_float_hip(to_fp16())); }
|
||||
explicit constexpr operator int() const
|
||||
{
|
||||
return static_cast<int>(fp16_to_float_hip(to_fp16()));
|
||||
}
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type& get() { return data; }
|
||||
constexpr raw_type& get() { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type get() const { return data; }
|
||||
constexpr raw_type get() const { return data; }
|
||||
};
|
||||
|
||||
// conversions
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float_hip(const _Float16& x)
|
||||
float fp16_to_float_hip(const fp16_hip_t& x)
|
||||
{
|
||||
// return __half2float(x);
|
||||
return static_cast<float>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
_Float16 float_to_fp16_hip(const float& x)
|
||||
fp16_hip_t float_to_fp16_hip(const float& x)
|
||||
{
|
||||
// return __float2half(x);
|
||||
return static_cast<_Float16>(x);
|
||||
return static_cast<fp16_hip_t>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
|
||||
@@ -290,7 +291,7 @@ float abs(const float& x)
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool isnan(const float& x)
|
||||
{
|
||||
uint32_t xx = reinterpret_cast<const uint32_t&>(x);
|
||||
uint32_t xx = bit_cast<uint32_t>(x);
|
||||
return (xx & 0x7fffffff) > 0x7F800000;
|
||||
}
|
||||
|
||||
|
||||
@@ -82,10 +82,10 @@ struct buffer_view<address_space_enum::generic,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
@@ -99,7 +99,7 @@ struct buffer_view<address_space_enum::generic,
|
||||
|
||||
if(is_valid_element)
|
||||
{
|
||||
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp;
|
||||
|
||||
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
|
||||
@@ -125,9 +125,9 @@ struct buffer_view<address_space_enum::generic,
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <InMemoryDataOperationEnum Op,
|
||||
typename X,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == InMemoryDataOperationEnum::set)
|
||||
@@ -144,9 +144,9 @@ struct buffer_view<address_space_enum::generic,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
@@ -159,7 +159,7 @@ struct buffer_view<address_space_enum::generic,
|
||||
|
||||
if(is_valid_element)
|
||||
{
|
||||
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp = x;
|
||||
|
||||
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
|
||||
@@ -253,10 +253,10 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
@@ -268,7 +268,7 @@ struct buffer_view<address_space_enum::global,
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
#if CK_USE_AMD_BUFFER_LOAD
|
||||
#if CK_TILE_USE_AMD_BUFFER_LOAD
|
||||
bool constexpr use_amd_buffer_addressing = true;
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
@@ -300,7 +300,7 @@ struct buffer_view<address_space_enum::global,
|
||||
{
|
||||
if(is_valid_element)
|
||||
{
|
||||
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp;
|
||||
|
||||
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
|
||||
@@ -326,10 +326,10 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const
|
||||
{
|
||||
@@ -348,9 +348,9 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const
|
||||
{
|
||||
@@ -370,9 +370,9 @@ struct buffer_view<address_space_enum::global,
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <InMemoryDataOperationEnum Op,
|
||||
typename X,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == InMemoryDataOperationEnum::set)
|
||||
@@ -399,10 +399,10 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
@@ -413,7 +413,7 @@ struct buffer_view<address_space_enum::global,
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
#if CK_USE_AMD_BUFFER_STORE
|
||||
#if CK_TILE_USE_AMD_BUFFER_STORE
|
||||
bool constexpr use_amd_buffer_addressing = true;
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
@@ -430,7 +430,7 @@ struct buffer_view<address_space_enum::global,
|
||||
{
|
||||
if(is_valid_element)
|
||||
{
|
||||
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp = x;
|
||||
|
||||
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
|
||||
@@ -443,10 +443,10 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set_raw(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
@@ -463,9 +463,9 @@ struct buffer_view<address_space_enum::global,
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void atomic_add(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
|
||||
@@ -480,14 +480,14 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
static_assert(get_address_space() == address_space_enum::global, "only support global mem");
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
|
||||
#if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
|
||||
bool constexpr use_amd_buffer_addressing =
|
||||
is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
|
||||
is_same_v<remove_cvref_t<scalar_t>, float> ||
|
||||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
|
||||
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
|
||||
#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
|
||||
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, int32_t>;
|
||||
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
|
||||
#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
|
||||
bool constexpr use_amd_buffer_addressing =
|
||||
is_same_v<remove_cvref_t<scalar_t>, float> ||
|
||||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
|
||||
@@ -512,9 +512,9 @@ struct buffer_view<address_space_enum::global,
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void atomic_max(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
@@ -527,7 +527,7 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
static_assert(get_address_space() == address_space_enum::global, "only support global mem");
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
|
||||
#if CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
|
||||
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
|
||||
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, double>;
|
||||
#else
|
||||
@@ -628,10 +628,10 @@ struct buffer_view<address_space_enum::lds,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
@@ -645,7 +645,7 @@ struct buffer_view<address_space_enum::lds,
|
||||
|
||||
if(is_valid_element)
|
||||
{
|
||||
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp;
|
||||
|
||||
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
|
||||
@@ -671,9 +671,9 @@ struct buffer_view<address_space_enum::lds,
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <InMemoryDataOperationEnum Op,
|
||||
typename X,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == InMemoryDataOperationEnum::set)
|
||||
@@ -690,9 +690,9 @@ struct buffer_view<address_space_enum::lds,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
@@ -703,7 +703,7 @@ struct buffer_view<address_space_enum::lds,
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
#if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
|
||||
#if CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
|
||||
bool constexpr workaround_int8_ds_write_issue = true;
|
||||
#else
|
||||
bool constexpr workaround_int8_ds_write_issue = false;
|
||||
@@ -807,7 +807,7 @@ struct buffer_view<address_space_enum::lds,
|
||||
{
|
||||
if(is_valid_element)
|
||||
{
|
||||
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp = x;
|
||||
|
||||
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
|
||||
@@ -899,10 +899,10 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
@@ -916,7 +916,7 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
|
||||
if(is_valid_element)
|
||||
{
|
||||
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp;
|
||||
|
||||
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
|
||||
@@ -942,9 +942,9 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <InMemoryDataOperationEnum Op,
|
||||
typename X,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == InMemoryDataOperationEnum::set)
|
||||
@@ -961,9 +961,9 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
@@ -976,7 +976,7 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
|
||||
if(is_valid_element)
|
||||
{
|
||||
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp = x;
|
||||
|
||||
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
|
||||
@@ -1024,13 +1024,13 @@ CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* p, BufferSizeType buffer_
|
||||
return buffer_view<BufferAddressSpace, T, BufferSizeType, true, Coherence>{p, buffer_size};
|
||||
}
|
||||
|
||||
template <
|
||||
address_space_enum BufferAddressSpace,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
typename T,
|
||||
typename BufferSizeType,
|
||||
typename X,
|
||||
typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type = false>
|
||||
template <address_space_enum BufferAddressSpace,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
typename T,
|
||||
typename BufferSizeType,
|
||||
typename X,
|
||||
typename std::enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)
|
||||
{
|
||||
@@ -1038,4 +1038,4 @@ make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)
|
||||
p, buffer_size, invalid_element_value};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace tile_program {
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
@@ -90,5 +89,4 @@ set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution
|
||||
dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
|
||||
}
|
||||
|
||||
} // namespace tile_program
|
||||
} // namespace ck_tile
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/limits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -457,7 +458,7 @@ transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
|
||||
generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, number<num_new_transform>{});
|
||||
|
||||
constexpr auto up_dim_numbers_scan = merge_sequences(
|
||||
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus<index_t>{}, number<0>{}));
|
||||
sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus<index_t>{}, number<0>{}));
|
||||
|
||||
constexpr auto up_dim_hidden_idss = generate_tuple(
|
||||
[ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
|
||||
@@ -510,7 +511,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
|
||||
// shift
|
||||
constexpr index_t adaptor0_max_hidden_id = [&]() {
|
||||
index_t adaptor0_max_hidden_id_ = NumericLimits<index_t>::Min();
|
||||
index_t adaptor0_max_hidden_id_ = numeric_limits<index_t>::min();
|
||||
|
||||
static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) {
|
||||
constexpr index_t ndim_low =
|
||||
@@ -536,7 +537,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
}();
|
||||
|
||||
constexpr index_t adaptor1_min_hidden_id = [&]() {
|
||||
index_t adaptor1_min_hidden_id_ = NumericLimits<index_t>::Max();
|
||||
index_t adaptor1_min_hidden_id_ = numeric_limits<index_t>::max();
|
||||
|
||||
static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) {
|
||||
constexpr index_t ndim_low =
|
||||
@@ -680,7 +681,9 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
remove_cvref_t<decltype(top_dim_hidden_ids)>>{all_transforms};
|
||||
}
|
||||
|
||||
template <typename X, typename... Xs, typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
|
||||
template <typename X,
|
||||
typename... Xs,
|
||||
typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
|
||||
{
|
||||
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
|
||||
|
||||
@@ -363,7 +363,7 @@ make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
|
||||
|
||||
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
|
||||
|
||||
const auto element_space_size = container_reduce(lengths, math::multiplies{}, long_number<1>{});
|
||||
const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});
|
||||
|
||||
using GuaranteedVectorLengths =
|
||||
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
|
||||
@@ -392,8 +392,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offs
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
|
||||
{
|
||||
const auto desc_0 = [&]() {
|
||||
const auto element_space_size =
|
||||
container_reduce(lengths, math::multiplies{}, long_number<1>{});
|
||||
const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});
|
||||
|
||||
const auto transforms = make_tuple(make_offset_transform(element_space_size, offset));
|
||||
|
||||
@@ -442,7 +441,7 @@ make_naive_tensor_descriptor_aligned(const tuple<Lengths...>& lengths, Align ali
|
||||
|
||||
constexpr index_t N = sizeof...(Lengths);
|
||||
|
||||
const auto stride_n_minus_2 = math::integer_least_multiple(lengths[number<N - 1>{}], align);
|
||||
const auto stride_n_minus_2 = integer_least_multiple(lengths[number<N - 1>{}], align);
|
||||
|
||||
auto strides = generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -456,12 +455,8 @@ make_naive_tensor_descriptor_aligned(const tuple<Lengths...>& lengths, Align ali
|
||||
}
|
||||
else
|
||||
{
|
||||
return container_reduce(lengths,
|
||||
math::multiplies{},
|
||||
number<stride_n_minus_2>{},
|
||||
i + I1,
|
||||
number<N - 1>{},
|
||||
I1);
|
||||
return container_reduce(
|
||||
lengths, multiplies{}, number<stride_n_minus_2>{}, i + I1, number<N - 1>{}, I1);
|
||||
}
|
||||
},
|
||||
number<N>{});
|
||||
|
||||
@@ -58,11 +58,12 @@ struct tensor_view
|
||||
#endif
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<DataType>>::type>,
|
||||
bool>::type = false>
|
||||
template <
|
||||
typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<DataType>>::type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_vectorized_elements(const TensorCoord& coord,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
@@ -75,11 +76,12 @@ struct tensor_view
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<DataType>>::type>,
|
||||
bool>::type = false>
|
||||
template <
|
||||
typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<DataType>>::type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
get_vectorized_elements_raw(remove_cvref_t<X>& dst,
|
||||
const TensorCoord& coord,
|
||||
@@ -91,10 +93,11 @@ struct tensor_view
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<DataType>>::type>,
|
||||
bool>::type = false>
|
||||
template <
|
||||
typename X,
|
||||
typename std::enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<DataType>>::type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord) const
|
||||
{
|
||||
@@ -103,11 +106,12 @@ struct tensor_view
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<DataType>>::type>,
|
||||
bool>::type = false>
|
||||
template <
|
||||
typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<DataType>>::type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(
|
||||
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
@@ -117,11 +121,12 @@ struct tensor_view
|
||||
x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<DataType>>::type>,
|
||||
bool>::type = false>
|
||||
template <
|
||||
typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<DataType>>::type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(
|
||||
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
@@ -172,9 +177,9 @@ template <AddressSpaceEnum BufferAddressSpace = AddressSpaceEnum::Generic,
|
||||
typename DataType,
|
||||
typename... Lengths,
|
||||
typename... Strides,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1,
|
||||
index_t GuaranteedLastDimensionVectorStride = -1,
|
||||
typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
index_t GuaranteedLastDimensionVectorLength = -1,
|
||||
index_t GuaranteedLastDimensionVectorStride = -1,
|
||||
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_view(DataType* p,
|
||||
const tuple<Lengths...>& lengths,
|
||||
@@ -245,8 +250,7 @@ pad_tensor_view(const tensor_view& tensor_view, const TileLengths& tile_lengths,
|
||||
|
||||
const auto tile_length = tile_lengths[idim];
|
||||
|
||||
const auto new_length =
|
||||
math::integer_divide_ceil(old_length, tile_length) * tile_length;
|
||||
const auto new_length = integer_divide_ceil(old_length, tile_length) * tile_length;
|
||||
|
||||
const auto pad_length = new_length - old_length;
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ struct tile_distribution
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr index_t x_length =
|
||||
container_reduce(typename DstrEncode::HsLengthss{}[i], math::multiplies{}, 1);
|
||||
container_reduce(typename DstrEncode::HsLengthss{}[i], multiplies{}, 1);
|
||||
|
||||
return number<x_length>{};
|
||||
},
|
||||
@@ -530,7 +530,7 @@ struct reverse_slice_sequence_impl<sequence<x, xs...>,
|
||||
|
||||
static constexpr auto slice_size = old_scan::remaining_slice_sizes::Front().value;
|
||||
static constexpr auto slice_length =
|
||||
std::conditional_t<m, number<math::gcd(x, slice_size)>, number<x>>::value;
|
||||
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
|
||||
|
||||
using dim_lengths =
|
||||
typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
|
||||
@@ -557,7 +557,7 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
|
||||
{
|
||||
static constexpr auto slice_size = SliceSize;
|
||||
static constexpr auto slice_length =
|
||||
std::conditional_t<m, number<math::gcd(x, slice_size)>, number<x>>::value;
|
||||
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
|
||||
|
||||
using dim_lengths = sequence<slice_length>;
|
||||
using dim_slices = sequence<x / slice_length>;
|
||||
|
||||
@@ -71,7 +71,7 @@ struct tile_distribution_encoding
|
||||
|
||||
// max_ndim_rh_minor_
|
||||
static constexpr index_t max_ndim_rh_minor_ =
|
||||
container_reduce(ndims_rhs_minor_, math::maximize<index_t>{}, 0);
|
||||
container_reduce(ndims_rhs_minor_, maximize<index_t>{}, 0);
|
||||
|
||||
// rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_]
|
||||
static constexpr auto rhs_lengthss_ =
|
||||
@@ -122,7 +122,7 @@ struct tile_distribution_encoding
|
||||
|
||||
// max_ndim_span_minor_
|
||||
static constexpr index_t max_ndim_span_minor_ =
|
||||
container_reduce(ndims_span_minor_, math::maximize<index_t>{}, 0);
|
||||
container_reduce(ndims_span_minor_, maximize<index_t>{}, 0);
|
||||
|
||||
// rhs_major_minor_to_span_minor_ [ndim_rh_major_][max_ndim_rh_minor_]
|
||||
static constexpr auto rhs_major_minor_to_span_minor_ = [] {
|
||||
@@ -293,8 +293,7 @@ struct tile_distribution_encoding
|
||||
template <typename IdxSeq, typename PrefixSumSeq>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
|
||||
{
|
||||
using sorted_idx =
|
||||
sequence_unique_sort<IdxSeq, math::less<index_t>, math::equal<index_t>>;
|
||||
using sorted_idx = sequence_unique_sort<IdxSeq, less<index_t>, equal<index_t>>;
|
||||
|
||||
constexpr auto sorted_dims = typename sorted_idx::type{};
|
||||
constexpr auto sorted_maps = typename sorted_idx::sorted2unsorted_map{};
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include <stdint.h>
|
||||
#include <utility>
|
||||
|
||||
|
||||
@@ -43,4 +43,30 @@ using is_known_at_compile_time = is_static<T>;
|
||||
// , this helper will also return false, which is not good(?)
|
||||
// do we need something like is_constexpr()?
|
||||
|
||||
namespace detail {
|
||||
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
|
||||
struct detector
|
||||
{
|
||||
using value_t = std::false_type;
|
||||
using type = Default;
|
||||
};
|
||||
|
||||
template <class Default, template <class...> class Op, class... Args>
|
||||
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
|
||||
{
|
||||
using value_t = std::true_type;
|
||||
using type = Op<Args...>;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
struct nonesuch
|
||||
{
|
||||
~nonesuch() = delete;
|
||||
nonesuch(nonesuch const&) = delete;
|
||||
void operator=(nonesuch const&) = delete;
|
||||
};
|
||||
|
||||
template <template <class...> class Op, class... Args>
|
||||
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -86,7 +86,7 @@ float launch_and_time_kernel_with_preprocess(const stream_config& s,
|
||||
#if CK_TILE_TIME_KERNEL
|
||||
if(s.time_kernel_)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
#if CK_TILE_DEBUG_LOG
|
||||
printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
|
||||
__func__,
|
||||
grid_dim.x,
|
||||
@@ -104,7 +104,7 @@ float launch_and_time_kernel_with_preprocess(const stream_config& s,
|
||||
hip_check_error(hipGetLastError());
|
||||
|
||||
const int nrepeat = 10;
|
||||
#if DEBUG_LOG
|
||||
#if CK_TILE_DEBUG_LOG
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
#endif
|
||||
hipEvent_t start, stop;
|
||||
|
||||
@@ -107,7 +107,7 @@ struct GenericAttentionMask
|
||||
index_t x_start = [&]() {
|
||||
if constexpr(IsLocal)
|
||||
{
|
||||
index_t tmp = math::max(-y + i_y + 1, 0);
|
||||
index_t tmp = max(-y + i_y + 1, 0);
|
||||
return (tmp / XTile) * XTile; // round to tile aligned
|
||||
}
|
||||
else
|
||||
@@ -119,7 +119,7 @@ struct GenericAttentionMask
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t x_end = [&]() {
|
||||
index_t tmp = math::min(i_y + YTile - 1 + x, x_total);
|
||||
index_t tmp = min(i_y + YTile - 1 + x, x_total);
|
||||
return ((tmp + XTile - 1) / XTile) * XTile;
|
||||
}();
|
||||
|
||||
@@ -138,7 +138,7 @@ struct GenericAttentionMask
|
||||
{
|
||||
// no need to do min/max here, since i_x will never be < 0 or >= x_total
|
||||
index_t x_start = -y + i_y + 1;
|
||||
index_t x_end = math::min(i_y + x, x_total);
|
||||
index_t x_end = min(i_y + x, x_total);
|
||||
|
||||
if constexpr(IsLocal)
|
||||
{
|
||||
@@ -164,7 +164,7 @@ struct GenericAttentionMask
|
||||
// check top-right corner > x or left-borrom corner < x
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t i_tile_bottom = i_tile_top + TileHeight;
|
||||
index_t x_end = math::min(i_tile_top + x, x_total);
|
||||
index_t x_end = min(i_tile_top + x, x_total);
|
||||
|
||||
bool top_right_edge = i_tile_right > (i_tile_top + x);
|
||||
bool bottom_left_edge = i_tile_bottom > (i_tile_left + y);
|
||||
@@ -176,7 +176,7 @@ struct GenericAttentionMask
|
||||
{
|
||||
// only need to check top-right corner > x
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t x_end = math::min(i_tile_top + x, x_total);
|
||||
index_t x_end = min(i_tile_top + x, x_total);
|
||||
|
||||
bool top_right_edge = i_tile_right > x_end;
|
||||
return top_right_edge;
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -195,7 +196,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = math::integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
@@ -322,7 +323,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
#if !CK_FMHA_FWD_FAST_EXP2
|
||||
x = scale * x + type_convert<SaccDataType>(bias_element_func(y));
|
||||
#else
|
||||
x = scale * x + math::log2e_v<SaccDataType> *
|
||||
x = scale * x + log2e_v<SaccDataType> *
|
||||
type_convert<SaccDataType>(bias_element_func(y));
|
||||
#endif
|
||||
},
|
||||
@@ -395,14 +396,14 @@ struct BlockFmhaPipelineQRKSVS
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max);
|
||||
p_compute(i_j_idx) = exp2(scale * s[i_j_idx] - row_max);
|
||||
}
|
||||
#else
|
||||
p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
@@ -419,16 +420,16 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale * get_validated_m(m[i_idx]);
|
||||
return math::exp2(scale * m_old[i_idx] - row_max);
|
||||
return exp2(scale * m_old[i_idx] - row_max);
|
||||
}
|
||||
}();
|
||||
#else
|
||||
const auto tmp = math::exp(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
@@ -511,14 +512,14 @@ struct BlockFmhaPipelineQRKSVS
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] / C_LOG2E + math::log(l_[i_idx]);
|
||||
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
else
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] * scale / C_LOG2E + math::log(l_[i_idx]);
|
||||
lse(i_idx) = m_[i_idx] * scale / C_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
#else
|
||||
lse(i_idx) = m_[i_idx] + math::log(l_[i_idx]);
|
||||
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
|
||||
#endif
|
||||
});
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -68,7 +69,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
static constexpr auto R_LOG2E = 1.0 / math::log2e_v<SaccDataType>;
|
||||
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
|
||||
#endif
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
@@ -238,7 +239,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = math::integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
@@ -365,7 +366,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
#if !CK_FMHA_FWD_FAST_EXP2
|
||||
x = scale * x + type_convert<SaccDataType>(bias_element_func(y));
|
||||
#else
|
||||
x = scale * x + math::log2e_v<SaccDataType> *
|
||||
x = scale * x + log2e_v<SaccDataType> *
|
||||
type_convert<SaccDataType>(bias_element_func(y));
|
||||
#endif
|
||||
},
|
||||
@@ -474,14 +475,14 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max);
|
||||
p_compute(i_j_idx) = exp2(scale * s[i_j_idx] - row_max);
|
||||
}
|
||||
#else
|
||||
p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
@@ -498,16 +499,16 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale * get_validated_m(m[i_idx]);
|
||||
return math::exp2(scale * m_old[i_idx] - row_max);
|
||||
return exp2(scale * m_old[i_idx] - row_max);
|
||||
}
|
||||
}();
|
||||
#else
|
||||
const auto tmp = math::exp(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
@@ -606,14 +607,14 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] * R_LOG2E + math::log(l_[i_idx]);
|
||||
lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
else
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] * scale * R_LOG2E + math::log(l_[i_idx]);
|
||||
lse(i_idx) = m_[i_idx] * scale * R_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
#else
|
||||
lse(i_idx) = m_[i_idx] + math::log(l_[i_idx]);
|
||||
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
|
||||
#endif
|
||||
});
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -187,7 +188,7 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = math::integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
@@ -306,8 +307,7 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
#if !CK_FMHA_FWD_FAST_EXP2
|
||||
x = scale * x + type_convert<SaccDataType>((y));
|
||||
#else
|
||||
x = scale * x +
|
||||
math::log2e_v<SaccDataType> * type_convert<SaccDataType>((y));
|
||||
x = scale * x + log2e_v<SaccDataType> * type_convert<SaccDataType>((y));
|
||||
#endif
|
||||
},
|
||||
s_acc,
|
||||
@@ -379,14 +379,14 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max);
|
||||
p_compute(i_j_idx) = exp2(scale * s[i_j_idx] - row_max);
|
||||
}
|
||||
#else
|
||||
p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
@@ -403,16 +403,16 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale * get_validated_m(m[i_idx]);
|
||||
return math::exp2(scale * m_old[i_idx] - row_max);
|
||||
return exp2(scale * m_old[i_idx] - row_max);
|
||||
}
|
||||
}();
|
||||
#else
|
||||
const auto tmp = math::exp(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -181,7 +182,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = math::integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
@@ -314,7 +315,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
#if !CK_FMHA_FWD_FAST_EXP2
|
||||
x = scale * x + type_convert<SaccDataType>(bias_element_func(y));
|
||||
#else
|
||||
x = scale * x + math::log2e_v<SaccDataType> *
|
||||
x = scale * x + log2e_v<SaccDataType> *
|
||||
type_convert<SaccDataType>(bias_element_func(y));
|
||||
#endif
|
||||
},
|
||||
@@ -387,14 +388,14 @@ struct BlockFmhaPipelineQSKSVS
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max);
|
||||
p_compute(i_j_idx) = exp2(scale * s[i_j_idx] - row_max);
|
||||
}
|
||||
#else
|
||||
p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
@@ -411,16 +412,16 @@ struct BlockFmhaPipelineQSKSVS
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale * get_validated_m(m[i_idx]);
|
||||
return math::exp2(scale * m_old[i_idx] - row_max);
|
||||
return exp2(scale * m_old[i_idx] - row_max);
|
||||
}
|
||||
}();
|
||||
#else
|
||||
const auto tmp = math::exp(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
@@ -503,14 +504,14 @@ struct BlockFmhaPipelineQSKSVS
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] / C_LOG2E + math::log(l_[i_idx]);
|
||||
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
else
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] * scale / C_LOG2E + math::log(l_[i_idx]);
|
||||
lse(i_idx) = m_[i_idx] * scale / C_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
#else
|
||||
lse(i_idx) = m_[i_idx] + math::log(l_[i_idx]);
|
||||
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
|
||||
#endif
|
||||
});
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -425,7 +425,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
|
||||
}();
|
||||
|
||||
return math::max(SingleKSize, SingleVSize);
|
||||
return max(SingleKSize, SingleVSize);
|
||||
}
|
||||
|
||||
template <typename Problem, typename BlockGemm>
|
||||
@@ -610,7 +610,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
// constexpr index_t SingleVSize =
|
||||
// MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
constexpr index_t BufferSize =
|
||||
GetSingleSmemElementSpaceSize<Problem>(); // math::max(SingleKSize, SingleVSize);
|
||||
GetSingleSmemElementSpaceSize<Problem>(); // max(SingleKSize, SingleVSize);
|
||||
|
||||
constexpr auto k_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumPrefetchK>{}, // num_buffers
|
||||
@@ -693,7 +693,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
GetSingleSmemElementSpaceSize<Problem>() * sizeof(typename Problem::KDataType);
|
||||
|
||||
return QXPolicy::template GetSmemSizeQ<Problem>() +
|
||||
single_smem_size * math::max(NumPrefetchK, NumPrefetchV);
|
||||
single_smem_size * max(NumPrefetchK, NumPrefetchV);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -22,10 +22,9 @@ struct TileFmhaShape
|
||||
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
|
||||
|
||||
static constexpr index_t NumWarps =
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, math::multiplies{}, number<1>{});
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
|
||||
|
||||
static_assert(NumWarps ==
|
||||
reduce_on_sequence(Gemm1BlockWarps{}, math::multiplies{}, number<1>{}));
|
||||
static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}));
|
||||
|
||||
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
|
||||
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
|
||||
|
||||
@@ -64,8 +64,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
math::integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
|
||||
16) *
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
|
||||
16;
|
||||
|
||||
// B tile in LDS
|
||||
|
||||
@@ -64,8 +64,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV2
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
math::integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
|
||||
16) *
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
|
||||
16;
|
||||
|
||||
// B tile in LDS
|
||||
|
||||
@@ -43,10 +43,10 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
|
||||
constexpr index_t lid_over_rid_derivative =
|
||||
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
|
||||
|
||||
static_assert(math::is_power_of_two_integer(r_length),
|
||||
static_assert(is_power_of_two_integer(r_length),
|
||||
"wrong! only support power of 2 reduction");
|
||||
|
||||
constexpr index_t nstage = math::integer_log2_floor(r_length);
|
||||
constexpr index_t nstage = integer_log2_floor(r_length);
|
||||
|
||||
// reduction sweep forward
|
||||
static_for<0, nstage, 1>{}([&](auto istage) {
|
||||
@@ -78,10 +78,10 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
|
||||
constexpr index_t lid_over_rid_derivative =
|
||||
DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r];
|
||||
|
||||
static_assert(math::is_power_of_two_integer(r_length),
|
||||
static_assert(is_power_of_two_integer(r_length),
|
||||
"wrong! only support power of 2 reduction");
|
||||
|
||||
constexpr index_t nstage = math::integer_log2_floor(r_length);
|
||||
constexpr index_t nstage = integer_log2_floor(r_length);
|
||||
|
||||
// broadcast sweep backward
|
||||
static_for<0, nstage, 1>{}([&](auto istage) {
|
||||
|
||||
@@ -65,6 +65,7 @@ class submodule_t:
|
||||
submodule = submodule_t()
|
||||
# formatting
|
||||
for x in all_files:
|
||||
subprocess.Popen(f'dos2unix {str(x)}', shell=True)
|
||||
cmd = f'clang-format-12 -style=file -i {str(x)}'
|
||||
#for xp in x.parents:
|
||||
#print(get_file_base(x))
|
||||
|
||||
@@ -5,13 +5,19 @@ rm -rf CMakeFiles
|
||||
|
||||
MY_PROJECT_SOURCE=$1
|
||||
|
||||
if [ $# -ge 2 ] ; then
|
||||
GPU_TARGETS=$2
|
||||
else
|
||||
GPU_TARGETS="gfx908;gfx90a;gfx940"
|
||||
fi
|
||||
|
||||
cmake \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D BUILD_DEV=ON \
|
||||
-D GPU_TARGETS="gfx908;gfx90a;gfx940" \
|
||||
-D GPU_TARGETS=$GPU_TARGETS \
|
||||
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
|
||||
-D USE_BITINT_EXTENSION_INT4=OFF \
|
||||
${MY_PROJECT_SOURCE}
|
||||
|
||||
@@ -5,13 +5,19 @@ rm -rf CMakeFiles
|
||||
|
||||
MY_PROJECT_SOURCE=$1
|
||||
|
||||
if [ $# -ge 2 ] ; then
|
||||
GPU_TARGETS=$2
|
||||
else
|
||||
GPU_TARGETS="gfx908;gfx90a;gfx940"
|
||||
fi
|
||||
|
||||
cmake \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_CXX_FLAGS="-O3" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D BUILD_DEV=OFF \
|
||||
-D GPU_TARGETS="gfx908;gfx90a;gfx940" \
|
||||
-D GPU_TARGETS=$GPU_TARGETS \
|
||||
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
|
||||
-D USE_BITINT_EXTENSION_INT4=OFF \
|
||||
${MY_PROJECT_SOURCE}
|
||||
|
||||
Reference in New Issue
Block a user