mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
CK Tile FA Training kernels (#1286)
* FA fwd dropout * FA bwd * epilogue reuse * CMakeLists update * [CK_TILE] support alibi (#1269) * add alibi support * fix code * update code based on comment * Support more hdim * fix fp8 bias * support seqlen_k=0 case * remove unused printf * fix format --------- Co-authored-by: rocking <ChunYu.Lai@amd.com> * now fwd/bwd can build * bwd alibi * add bwd validation stream_config * update generated filenames * update bwd kernel launch * CK_TILE_HOST_DEVICE in philox * Transpose -> transpose * format * format * format * Generate the instance for FA required * format * fix error in WarpGemm --------- Co-authored-by: danyao12 <danyao12> Co-authored-by: carlushuang <carlus.huang@amd.com> Co-authored-by: rocking <ChunYu.Lai@amd.com> Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com> Co-authored-by: Jing Zhang <jizhan@amd.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -156,7 +156,7 @@ struct HostTensorDescriptor
|
||||
}
|
||||
|
||||
// Read-only access to the stored per-dimension lengths (mLens).
auto get_lengths() const -> const std::vector<std::size_t>& { return mLens; }
|
||||
const std::vector<std::size_t>& GetStrides() const { return mStrides; }
|
||||
// Read-only access to the stored per-dimension strides (mStrides).
auto get_strides() const -> const std::vector<std::size_t>& { return mStrides; }
|
||||
|
||||
template <typename... Is>
|
||||
std::size_t GetOffsetFromMultiIndex(Is... is) const
|
||||
@@ -188,7 +188,7 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
|
||||
for(std::size_t i = 0; i < a.get_num_of_dimension(); i++)
|
||||
{
|
||||
new_lengths[i] = a.get_lengths()[new2old[i]];
|
||||
new_strides[i] = a.GetStrides()[new2old[i]];
|
||||
new_strides[i] = a.get_strides()[new2old[i]];
|
||||
}
|
||||
|
||||
return HostTensorDescriptor(new_lengths, new_strides);
|
||||
@@ -327,7 +327,7 @@ struct HostTensor
|
||||
|
||||
decltype(auto) get_lengths() const { return mDesc.get_lengths(); }
|
||||
|
||||
decltype(auto) GetStrides() const { return mDesc.GetStrides(); }
|
||||
decltype(auto) get_strides() const { return mDesc.get_strides(); }
|
||||
|
||||
std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }
|
||||
|
||||
@@ -481,6 +481,34 @@ struct HostTensor
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
|
||||
/// Returns a tensor whose dimensions are a permutation of this tensor's dimensions.
///
/// Output dimension i takes the length and stride of input dimension axes[i].
/// The result is copy-constructed from *this and then given the permuted
/// descriptor, so this tensor is left untouched.
///
/// @param axes  Permutation of [0, get_num_of_dimension()). When empty, the
///              dimension order is reversed.
/// @return      The permuted tensor.
/// @throws std::runtime_error if axes.size() differs from the tensor dimension,
///         or if axes contains an out-of-range or repeated entry.
HostTensor<T> transpose(std::vector<size_t> axes = {}) const
{
    const std::size_t rank = mDesc.get_num_of_dimension();

    if(axes.empty())
    {
        // Default: reversed order, i.e. {rank-1, ..., 1, 0}.
        axes.resize(rank);
        std::iota(axes.rbegin(), axes.rend(), 0);
    }
    if(axes.size() != rank)
    {
        throw std::runtime_error(
            "HostTensor::transpose(): size of axes must match tensor dimension");
    }

    // Reject anything that is not a permutation before indexing with it: an
    // out-of-range axis would read past the end of the length/stride vectors.
    std::vector<bool> seen(rank, false);
    for(const auto axis : axes)
    {
        if(axis >= rank || seen[axis])
        {
            throw std::runtime_error(
                "HostTensor::transpose(): axes must be a permutation of the tensor dimensions");
        }
        seen[axis] = true;
    }

    const auto& lengths = get_lengths();
    const auto& strides = get_strides();

    std::vector<size_t> tlengths, tstrides;
    tlengths.reserve(rank);
    tstrides.reserve(rank);
    for(const auto axis : axes)
    {
        tlengths.push_back(lengths[axis]);
        tstrides.push_back(strides[axis]);
    }

    HostTensor<T> ret(*this);
    ret.mDesc = HostTensorDescriptor(tlengths, tstrides);
    return ret;
}
|
||||
|
||||
// Non-const overload: forwards to the const-qualified transpose() with the
// same axes argument.
HostTensor<T> transpose(std::vector<size_t> axes = {})
{
    const HostTensor<T>& self = *this;
    return self.transpose(axes);
}
|
||||
|
||||
// Mutable iterator to the first element of the underlying data container.
auto begin() -> typename Data::iterator { return mData.begin(); }
|
||||
|
||||
// Mutable past-the-end iterator of the underlying data container.
auto end() -> typename Data::iterator { return mData.end(); }
|
||||
|
||||
Reference in New Issue
Block a user