CK Tile FA Training kernels (#1286)

* FA fwd dropout

* FA bwd

* epilogue reuse

* CMakeLists update

* [CK_TILE] support alibi (#1269)

* add alibi support

* fix code

* update code based on comment

* Support more hdim

* fix fp8 bias

* support seqlen_k=0 case

* remove unused printf

* fix format

---------

Co-authored-by: rocking <ChunYu.Lai@amd.com>

* now fwd/bwd can build

* bwd alibi

* add bwd validation stream_config

* update generated filenames

* update bwd kernel launch

* CK_TILE_HOST_DEVICE in philox

* Transpose -> transpose

* format

* format

* format

* Generate the instances required for FA

* format

* fix error in WarpGemm

---------

Co-authored-by: danyao12 <danyao12>
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Jing Zhang <jizhan@amd.com>
This commit is contained in:
Dan Yao
2024-06-05 02:12:45 +08:00
committed by GitHub
parent 76827d82ca
commit 2cab8d39e3
70 changed files with 9506 additions and 482 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -156,7 +156,7 @@ struct HostTensorDescriptor
}
// Returns the stored per-dimension length vector (mLens) by const reference.
const std::vector<std::size_t>& get_lengths() const { return mLens; }
// Returns the stored per-dimension stride vector (mStrides) by const reference.
// NOTE(review): get_strides() below has an identical body; this looks like the
// pre-rename spelling of the same accessor - confirm which one the file keeps.
const std::vector<std::size_t>& GetStrides() const { return mStrides; }
// Returns the stored per-dimension stride vector (mStrides) by const reference.
const std::vector<std::size_t>& get_strides() const { return mStrides; }
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
@@ -188,7 +188,7 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
for(std::size_t i = 0; i < a.get_num_of_dimension(); i++)
{
new_lengths[i] = a.get_lengths()[new2old[i]];
new_strides[i] = a.GetStrides()[new2old[i]];
new_strides[i] = a.get_strides()[new2old[i]];
}
return HostTensorDescriptor(new_lengths, new_strides);
@@ -327,7 +327,7 @@ struct HostTensor
// Forwards to mDesc.get_lengths(); decltype(auto) keeps the descriptor's exact
// return type (including reference-ness) instead of decaying to a copy.
decltype(auto) get_lengths() const { return mDesc.get_lengths(); }
// Forwards to mDesc.GetStrides().
// NOTE(review): get_strides() below forwards to the snake_case descriptor accessor;
// this looks like the pre-rename spelling - confirm which one the file keeps.
decltype(auto) GetStrides() const { return mDesc.GetStrides(); }
// Forwards to mDesc.get_strides(), preserving its return type via decltype(auto).
decltype(auto) get_strides() const { return mDesc.get_strides(); }
// Number of dimensions, as reported by the descriptor mDesc.
std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }
@@ -481,6 +481,34 @@ struct HostTensor
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
/// Returns a copy of this tensor whose descriptor has its dimensions permuted.
///
/// The result is copy-constructed from *this and only its descriptor is replaced:
/// dimension i of the result takes the length and stride of dimension axes[i] of
/// this tensor. The stored elements themselves are not rearranged.
///
/// @param axes  Source-dimension index for each result dimension. When empty, the
///              dimension order is reversed (last dimension first).
/// @return      A new HostTensor<T> with the permuted descriptor.
/// @throws std::runtime_error if axes.size() differs from the tensor's number of
///         dimensions, or if any entry of axes is not a valid dimension index.
HostTensor<T> transpose(std::vector<size_t> axes = {}) const
{
    const std::size_t rank = this->get_num_of_dimension();

    if(axes.empty())
    {
        // Default permutation: fill back-to-front so that axes == {rank-1, ..., 1, 0}.
        axes.resize(rank);
        std::iota(axes.rbegin(), axes.rend(), 0);
    }
    if(axes.size() != rank)
    {
        throw std::runtime_error(
            "HostTensor::transpose(): size of axes must match tensor dimension");
    }

    const auto& src_lengths = get_lengths();
    const auto& src_strides = get_strides();

    std::vector<size_t> tlengths, tstrides;
    tlengths.reserve(rank);
    tstrides.reserve(rank);
    for(const auto& axis : axes)
    {
        // operator[] on the length/stride vectors is unchecked, so reject a bad axis
        // here instead of reading past the end of the vectors.
        if(axis >= rank)
        {
            throw std::runtime_error(
                "HostTensor::transpose(): axis is out of range for tensor dimension");
        }
        tlengths.push_back(src_lengths[axis]);
        tstrides.push_back(src_strides[axis]);
    }

    HostTensor<T> ret(*this);
    ret.mDesc = HostTensorDescriptor(tlengths, tstrides);
    return ret;
}
/// Non-const overload of transpose(): delegates to the const overload so that both
/// share a single implementation. `axes` has the same meaning as there.
HostTensor<T> transpose(std::vector<size_t> axes = {})
{
    // Viewing *this through a const reference selects the const overload.
    const HostTensor<T>& self = *this;
    return self.transpose(axes);
}
// Mutable iterator to the first element of the underlying container mData.
// Iterates the container directly; the descriptor mDesc is not consulted.
typename Data::iterator begin() { return mData.begin(); }
// Mutable past-the-end iterator of the underlying container mData.
typename Data::iterator end() { return mData.end(); }