mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Attention with output permutation (#370)
* comment on specialization for TensorSpecialization::Packed * gemm_softmax_gemm with output permutation * scaling * refactor MatrixPadder; rename to GemmPadder * remove old sanity check * restore original gemm_softmax_gemm * revise comment in gemm_softmax_gemm example * use GetElementSpaceSize() * remove extra header * typo * remove archaic DeviceOpPtr
This commit is contained in:
@@ -129,6 +129,25 @@ namespace device {
|
||||
// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
//
// FIXME: TensorSpecialization::Packed specialization does not cover all packed tensor cases; it
// merely degenerates into TensorSpecialization::Default with NumDimG/M/N/K = 1.
//
// Detail - a packed tensor satisfies
//   stride_0 = 1
//   stride_i = stride_{i - 1} * extent_{i - 1}
// So the tensor
//   [G0, G1, G2, M, N]
// transposed into the tensor
//   [G0, G2, G1, M, N]
// with strides
//   [G2 * G1 * M * N, G1 * M * N, M * N, N, 1]
// is again a packed tensor. MakeGridDescriptor() currently just merges dimensions and ignores some
// strides from the input tensor extents, so finer dimension information is lost. Merging dimensions
// is essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1.
//
// Might need to expose dimension order to the interface to fully support
// TensorSpecialization::Packed.
|
||||
template <index_t NumDimG,
|
||||
index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
|
||||
@@ -54,33 +54,6 @@ struct DeviceBatchedGemmGemm : public BaseOperator
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
// Owning-pointer convenience alias for the abstract batched GEMM + GEMM device
// operator; parameterized on the same layout / data-type / elementwise-operation
// arguments as DeviceBatchedGemmGemm itself.
template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
          typename CLayout,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename Acc0ElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation>
using DeviceBatchedGemmGemmPtr = std::unique_ptr<DeviceBatchedGemmGemm<ALayout,
                                                                       B0Layout,
                                                                       B1Layout,
                                                                       CLayout,
                                                                       ADataType,
                                                                       B0DataType,
                                                                       B1DataType,
                                                                       CDataType,
                                                                       AElementwiseOperation,
                                                                       B0ElementwiseOperation,
                                                                       Acc0ElementwiseOperation,
                                                                       B1ElementwiseOperation,
                                                                       CElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -54,34 +54,6 @@ struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
// Owning-pointer convenience alias for the abstract batched GEMM + softmax + GEMM
// device operator; parameterized on the same layout / data-type /
// elementwise-operation arguments as DeviceBatchedGemmSoftmaxGemm itself.
template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
          typename CLayout,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename Acc0ElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation>
using DeviceBatchedGemmSoftmaxGemmPtr =
    std::unique_ptr<DeviceBatchedGemmSoftmaxGemm<ALayout,
                                                 B0Layout,
                                                 B1Layout,
                                                 CLayout,
                                                 ADataType,
                                                 B0DataType,
                                                 B1DataType,
                                                 CDataType,
                                                 AElementwiseOperation,
                                                 B0ElementwiseOperation,
                                                 Acc0ElementwiseOperation,
                                                 B1ElementwiseOperation,
                                                 CElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
#include <memory>
#include <vector>

#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Abstract interface for a batched GEMM + softmax + GEMM device operator whose
// output tensor C is described by explicit [G..., M..., O...] lengths/strides
// rather than a single (StrideC, BatchStrideC) pair, allowing a permuted
// output layout. Concrete implementations derive from this and realize
// MakeArgumentPointer()/MakeInvokerPointer().
template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
          typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>: how many of the
                                               // output dims belong to the G,
                                               // M and Gemm1-N groups
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename Acc0ElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
{
    // Builds a type-erased argument object for a later invocation.
    //
    // p_a / p_b0 / p_b1 : device pointers to the A, B0 and B1 input tensors
    // p_c               : device pointer to the output tensor C
    // M, N, K, O        : problem sizes. Presumably Gemm0 is MxK * KxN and
    //                     Gemm1 multiplies by B1 with output width O (the
    //                     "Gemm1 N" dimension) — TODO confirm against a
    //                     concrete implementation.
    // Batch             : number of batched problems (the G dimension count)
    // c_gs_ms_os_lengths/strides : per-dimension extents and strides of the
    //                     (possibly permuted) output tensor C, ordered
    //                     [G..., M..., O...]; note C has no scalar
    //                     Stride/BatchStride parameters — its layout comes
    //                     entirely from these vectors.
    // StrideA/B0/B1, BatchStrideA/B0/B1 : leading-dimension and per-batch
    //                     strides of the three input tensors.
    // *_element_op      : elementwise operations applied to A, B0, the Gemm0
    //                     accumulator (e.g. scaling before softmax —
    //                     presumably; verify against caller), B1 and C.
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b0,
                        const void* p_b1,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t O,
                        ck::index_t Batch,
                        std::vector<index_t> c_gs_ms_os_lengths,
                        std::vector<index_t> c_gs_ms_os_strides,
                        ck::index_t StrideA,
                        ck::index_t StrideB0,
                        ck::index_t StrideB1,
                        ck::index_t BatchStrideA,
                        ck::index_t BatchStrideB0,
                        ck::index_t BatchStrideB1,
                        AElementwiseOperation a_element_op,
                        B0ElementwiseOperation b0_element_op,
                        Acc0ElementwiseOperation acc0_element_op,
                        B1ElementwiseOperation b1_element_op,
                        CElementwiseOperation c_element_op) = 0;

    // Builds the invoker that launches the kernel for an argument created above.
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user