mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
CK Tile FA Training kernels (#1286)
* FA fwd dropout * FA bwd * epilogue reuse * CMakeLists update * [CK_TILE] support alibi (#1269) * add alibi support * fix code * update code based on comment * Support more hdim * fix fp8 bias * support seqlen_k=0 case * remove unused printf * fix format --------- Co-authored-by: rocking <ChunYu.Lai@amd.com> * now fwd/bwd can build * bwd alibi * add bwd validation stream_config * update generated filenames * update bwd kernel launch * CK_TILE_HOST_DEVICE in philox * Transpose -> transpose * format * format * format * Generate the instance for FA required * format * fix error in WarpGemm --------- Co-authored-by: danyao12 <danyao12> Co-authored-by: carlushuang <carlus.huang@amd.com> Co-authored-by: rocking <ChunYu.Lai@amd.com> Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com> Co-authored-by: Jing Zhang <jizhan@amd.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -783,6 +783,28 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
|
||||
|
||||
// buffer store ui16
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_ui16x2(uint16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_ui16x4(uint16x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
@@ -1353,7 +1375,10 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, uint16_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(std::is_same<T, float>::value) // fp32
|
||||
@@ -1492,6 +1517,49 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, uint16_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_ui16(bit_cast<uint16_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_ui16x2(bit_cast<uint16x2_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_ui16x4(bit_cast<uint16x4_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_ui16x4(
|
||||
src_thread_data.template get_as<uint16x4_t>()[number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_ui16x4(
|
||||
src_thread_data.template get_as<uint16x4_t>()[number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 4 * sizeof(uint16_t),
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
using r_t = thread_buffer<int8_t, sizeof(T) * N>;
|
||||
@@ -1609,7 +1677,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
|
||||
{
|
||||
if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast<fp16_t>(src_thread_data),
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast<fp16x2_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
|
||||
175
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
Normal file
175
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
Normal file
@@ -0,0 +1,175 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b)
|
||||
{
|
||||
return type_convert<bf16_t>(type_convert<float>(a) + type_convert<float>(b));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
|
||||
{
|
||||
bf16x2_t rtn;
|
||||
rtn[0] = add_bf16_t(a[0], b[0]);
|
||||
rtn[1] = add_bf16_t(a[1], b[1]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
// Caution: DO NOT REMOVE
|
||||
// intentionally have only declaration but no definition to cause compilation failure when trying to
|
||||
// instantiate this template. The purpose is to make the implementation of atomic_add explicit for
|
||||
// each datatype.
|
||||
template <typename X>
|
||||
CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
|
||||
{
|
||||
union U32BF162_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
bf16x2_t* bf162_a;
|
||||
};
|
||||
|
||||
union U32BF162
|
||||
{
|
||||
uint32_t u32;
|
||||
bf16x2_t bf162;
|
||||
};
|
||||
|
||||
U32BF162_ADDR dword_addr;
|
||||
U32BF162 cur_v;
|
||||
U32BF162 new_;
|
||||
uint32_t old_v, new_v;
|
||||
dword_addr.bf162_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.bf162 = add_bf16x2_t(cur_v.bf162, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4)),
|
||||
"wrong! not implemented");
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
if constexpr(std::is_same<T, float>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicAdd(p_dst, bit_cast<float>(x));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, double>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return atomicAdd(p_dst, bit_cast<double>(x));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
atomicAdd(c_style_pointer_cast<double*>(p_dst), x.template get_as<double>()[I0]);
|
||||
atomicAdd(c_style_pointer_cast<double*>(p_dst) + 1, x.template get_as<double>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicAdd(p_dst, bit_cast<int32_t>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, uint32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicAdd(p_dst, bit_cast<uint32_t>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf16_t>::value)
|
||||
{
|
||||
if constexpr(N == 2)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), bit_cast<bf16x2_t>(x));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]);
|
||||
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst) + 1,
|
||||
x.template get_as<bf16x2_t>()[I1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void atomic_max_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, double>::value && (N == 1)),
|
||||
"wrong! not implemented");
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
if constexpr(std::is_same<T, float>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<float>(x));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
atomicMax(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
|
||||
atomicMax(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, double>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<double>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<int32_t>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, uint32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<uint32_t>(x));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -144,6 +144,15 @@ using int8x16_t = int8_t __attribute((ext_vector_type(16)));
|
||||
using int8x32_t = int8_t __attribute((ext_vector_type(32)));
|
||||
using int8x64_t = int8_t __attribute((ext_vector_type(64)));
|
||||
|
||||
// ui8
|
||||
// using uint8_t
|
||||
using uint8x2_t = uint8_t __attribute((ext_vector_type(2)));
|
||||
using uint8x4_t = uint8_t __attribute((ext_vector_type(4)));
|
||||
using uint8x8_t = uint8_t __attribute((ext_vector_type(8)));
|
||||
using uint8x16_t = uint8_t __attribute((ext_vector_type(16)));
|
||||
using uint8x32_t = uint8_t __attribute((ext_vector_type(32)));
|
||||
using uint8x64_t = uint8_t __attribute((ext_vector_type(64)));
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
// f8
|
||||
// using fp8_t
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
@@ -507,10 +508,10 @@ struct buffer_view<address_space_enum::global,
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
if constexpr(use_amd_buffer_addressing)
|
||||
{
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, buffer_size_);
|
||||
}
|
||||
@@ -518,7 +519,7 @@ struct buffer_view<address_space_enum::global,
|
||||
{
|
||||
if(is_valid_element)
|
||||
{
|
||||
atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
|
||||
atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -547,16 +548,16 @@ struct buffer_view<address_space_enum::global,
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
if constexpr(use_amd_buffer_addressing)
|
||||
{
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, buffer_size_);
|
||||
}
|
||||
else if(is_valid_element)
|
||||
{
|
||||
atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
|
||||
atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -16,7 +16,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BufferView_, typename TensorDesc_>
|
||||
template <typename BufferView_,
|
||||
typename TensorDesc_,
|
||||
memory_operation_enum DstInMemOp_ = memory_operation_enum::set>
|
||||
struct tensor_view
|
||||
{
|
||||
using buffer_view = remove_reference_t<BufferView_>;
|
||||
@@ -24,6 +26,7 @@ struct tensor_view
|
||||
using TensorDesc = remove_cvref_t<TensorDesc_>;
|
||||
using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
|
||||
using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
|
||||
static constexpr auto DstInMemOp = DstInMemOp_;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_view() = default;
|
||||
|
||||
@@ -140,6 +143,23 @@ struct tensor_view
|
||||
x);
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements(
|
||||
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template update<DstInMemOp, X, oob_conditional_check>(
|
||||
coord.get_offset(),
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tensor_view{");
|
||||
@@ -178,6 +198,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
|
||||
}
|
||||
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
memory_operation_enum DstInMemOp = memory_operation_enum::set,
|
||||
typename DataType,
|
||||
typename... Lengths,
|
||||
typename... Strides,
|
||||
@@ -198,7 +219,7 @@ make_naive_tensor_view(DataType* p,
|
||||
|
||||
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
|
||||
|
||||
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
|
||||
return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
@@ -232,8 +253,9 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& ol
|
||||
NewLowerDimensionOldVisibleIdss{},
|
||||
NewUpperDimensionNewVisibleIdss{});
|
||||
|
||||
return tensor_view<typename OldTensorView::buffer_view, remove_cvref_t<decltype(new_desc)>>{
|
||||
old_tensor_view.buf_, new_desc};
|
||||
return tensor_view<typename OldTensorView::buffer_view,
|
||||
remove_cvref_t<decltype(new_desc)>,
|
||||
remove_cvref_t<OldTensorView>::DstInMemOp>{old_tensor_view.buf_, new_desc};
|
||||
}
|
||||
|
||||
template <typename TensorView,
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/meta_data_buffer.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
@@ -594,6 +594,66 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
template <bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_array(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys =
|
||||
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// move thread's botom tensor coordiante
|
||||
// [x0', x1', ... ] ==> [offset]
|
||||
// also move window-origin
|
||||
|
||||
55
include/ck_tile/core/tensor/update_tile.hpp
Normal file
55
include/ck_tile/core/tensor/update_tile.hpp
Normal file
@@ -0,0 +1,55 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
tile_dstr);
|
||||
|
||||
tile_window.update(dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
update_tile(tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.update(dstr_tensor);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
89
include/ck_tile/core/utility/philox_rand.hpp
Normal file
89
include/ck_tile/core/utility/philox_rand.hpp
Normal file
@@ -0,0 +1,89 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh
|
||||
class philox
|
||||
{
|
||||
public:
|
||||
CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_)
|
||||
: seed(reinterpret_cast<const uint2&>(seed_))
|
||||
{
|
||||
|
||||
ull2* tmp = reinterpret_cast<ull2*>(&counter);
|
||||
tmp->x = offset_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const
|
||||
{
|
||||
|
||||
uint4 counter_ = counter;
|
||||
ull2* tmp = reinterpret_cast<ull2*>(&counter_);
|
||||
tmp->y = subsequence;
|
||||
|
||||
uint2 key_ = seed;
|
||||
// 7-round philox
|
||||
#pragma unroll
|
||||
for(int i = 0; i < 6; i++)
|
||||
{
|
||||
counter_ = philox_single_round(counter_, key_);
|
||||
key_.x += kPhilox10A;
|
||||
key_.y += kPhilox10B;
|
||||
}
|
||||
uint4 output = philox_single_round(counter_, key_);
|
||||
return output;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t* out,
|
||||
const unsigned long long subsequence) const
|
||||
{
|
||||
uint4 tmp_ph;
|
||||
tmp_ph = get_philox_4x32(subsequence);
|
||||
|
||||
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
|
||||
|
||||
out_tmp[0] = tmp_ph.x;
|
||||
out_tmp[1] = tmp_ph.y;
|
||||
out_tmp[2] = tmp_ph.z;
|
||||
out_tmp[3] = tmp_ph.w;
|
||||
}
|
||||
|
||||
private:
|
||||
struct ull2
|
||||
{
|
||||
uint64_t x;
|
||||
uint64_t y;
|
||||
};
|
||||
uint4 counter;
|
||||
const uint2 seed;
|
||||
|
||||
CK_TILE_HOST_DEVICE uint2 mulhilo32(const unsigned int a, const unsigned int b) const
|
||||
{
|
||||
uint2* res;
|
||||
unsigned long long tmp;
|
||||
tmp = static_cast<unsigned long long>(a) * b;
|
||||
res = reinterpret_cast<uint2*>(&tmp);
|
||||
return *res;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint4 philox_single_round(const uint4 ctr, const uint2 key) const
|
||||
{
|
||||
|
||||
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
|
||||
uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
|
||||
uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
|
||||
return ret;
|
||||
}
|
||||
|
||||
static const unsigned long kPhilox10A = 0x9E3779B9;
|
||||
static const unsigned long kPhilox10B = 0xBB67AE85;
|
||||
static const unsigned long kPhiloxSA = 0xD2511F53;
|
||||
static const unsigned long kPhiloxSB = 0xCD9E8D57;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user