mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Input/output permutation for fused attention (#460)
* reopen masking att instance due to CI is upgraded
* re-enable instances previously failed on 9110
* enable ksize-kpadding pair validity test
* add non-masked attention+permute test; expose masking boolean to attention kernel handles
* disable bench
* fix test
* move files
* bulk rename batched_gemm_masking_scale_softmax_gemm_permute to batched_gemm_softmax_gemm_permute
* format
* amend rename
* disable bench in test
* add mask/no-mask test for non-permute attention kernels
* disable broken kernel instance
* example working
add non-permuted problem statement
evaluating whether overhead comes from permutation or the extra kernel arg
* interface for bias addition without implementing it
* test and profiler running
* tidy
* mask type determined by enum class
* unify example code
* move masking specialization to its own header
* align formats
* extract helper functions
* experiment merging dims for attn w/ permute; shows perf parity with attn wo/ permute
* add tensor specialization to template args
since tensor spec packed shows perf parity when permutation isn't needed
remove redundant template args
comment on 'packed' tensor specialization
* grouped attention with input/output permute example
* format
* clean up
* refactor acc0 tile visitor
Co-authored-by: shaojiewang <wsjmessi@163.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: de37550f72]
This commit is contained in:
@@ -29,7 +29,8 @@ template <typename ADataType,
|
||||
typename ALayout,
|
||||
typename B0Layout,
|
||||
typename B1Layout,
|
||||
typename CLayout>
|
||||
typename CLayout,
|
||||
bool MaskOutUpperTriangle>
|
||||
bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
@@ -46,16 +47,18 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
|
||||
int BatchStrideA = -1,
|
||||
int BatchStrideB0 = -1,
|
||||
int BatchStrideB1 = -1,
|
||||
int BatchStrideC = -1)
|
||||
int BatchStrideC = -1,
|
||||
float alpha = 1.f)
|
||||
|
||||
{
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
using PassThrough = tensor_operation::element_wise::PassThrough;
|
||||
using Scale = tensor_operation::element_wise::Scale;
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = Scale;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
using AccDataType = float;
|
||||
@@ -67,7 +70,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
CElementOp>;
|
||||
Acc0ElementOp>;
|
||||
|
||||
// Ref Softmax: fp32 in, various type out
|
||||
using ReferenceSoftmaxInstance =
|
||||
@@ -185,7 +188,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b0_element_op = B0ElementOp{};
|
||||
auto acc0_element_op = Acc0ElementOp{};
|
||||
auto acc0_element_op = Acc0ElementOp{alpha};
|
||||
auto b1_element_op = B1ElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
@@ -201,7 +204,8 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
|
||||
B0ElementOp,
|
||||
Acc0ElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp>;
|
||||
CElementOp,
|
||||
MaskOutUpperTriangle>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
@@ -214,10 +218,16 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
|
||||
auto ref_gemm0 = ReferenceGemm0Instance{};
|
||||
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
|
||||
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
|
||||
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, PassThrough{});
|
||||
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, Scale{alpha});
|
||||
|
||||
ref_gemm0_invoker.Run(ref_gemm0_argument);
|
||||
|
||||
// mask out upper triangle
|
||||
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
|
||||
if(MaskOutUpperTriangle && idx[1] < idx[2])
|
||||
self(idx) = -ck::NumericLimits<float>::Infinity();
|
||||
});
|
||||
|
||||
auto ref_softmax = ReferenceSoftmaxInstance{};
|
||||
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
|
||||
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
|
||||
|
||||
@@ -7,10 +7,10 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
@@ -22,36 +22,32 @@
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename ADataType,
|
||||
template <index_t NumDimG,
|
||||
index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
index_t NumDimK,
|
||||
index_t NumDimO,
|
||||
typename ADataType,
|
||||
typename B0DataType,
|
||||
typename B1DataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename B0Layout,
|
||||
typename B1Layout,
|
||||
typename CPermuteNumDims_G_M_O>
|
||||
bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int O,
|
||||
int G0,
|
||||
int G1,
|
||||
int StrideA = -1,
|
||||
int StrideB0 = -1,
|
||||
int StrideB1 = -1,
|
||||
int BatchStrideA = -1,
|
||||
int BatchStrideB0 = -1,
|
||||
int BatchStrideB1 = -1,
|
||||
float alpha = 1.f)
|
||||
typename Acc0BiasesDataType,
|
||||
typename Acc1BiasesDataType,
|
||||
tensor_operation::device::MaskingSpecialization MaskingSpec>
|
||||
bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int O,
|
||||
int G0,
|
||||
int G1,
|
||||
float alpha = 1.f)
|
||||
|
||||
{
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
using PassThrough = tensor_operation::element_wise::PassThrough;
|
||||
using Scale = tensor_operation::element_wise::Scale;
|
||||
using AElementOp = PassThrough;
|
||||
@@ -60,6 +56,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
using AccDataType = float;
|
||||
using tensor_operation::device::MaskingSpecialization;
|
||||
|
||||
// Ref Gemm0: various type in, fp32 out
|
||||
using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
@@ -85,67 +82,33 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
|
||||
|
||||
bool pass = true;
|
||||
|
||||
// A layout [G0, M, G1, K]
|
||||
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
|
||||
std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};
|
||||
|
||||
// B0 layout [G0, N, G1, K]
|
||||
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
|
||||
std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};
|
||||
|
||||
// B1 layout [G0, N, G1, O]
|
||||
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
|
||||
std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};
|
||||
|
||||
// C layout [G0, M, G1, O]
|
||||
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
|
||||
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
|
||||
|
||||
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
|
||||
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
|
||||
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
|
||||
|
||||
StrideA = (StrideA < 0) ? DefaultStrideA : StrideA;
|
||||
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
|
||||
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
|
||||
|
||||
const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
|
||||
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
|
||||
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
|
||||
|
||||
BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
|
||||
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
|
||||
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
|
||||
|
||||
const int BatchCount = G0 * G1;
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t batch_count,
|
||||
std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
std::size_t batch_stride,
|
||||
auto layout) {
|
||||
if(std::is_same<decltype(layout), Row>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
|
||||
std::vector<std::size_t>({batch_stride, stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
|
||||
std::vector<std::size_t>({batch_stride, 1, stride}));
|
||||
}
|
||||
};
|
||||
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
|
||||
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
|
||||
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
|
||||
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
|
||||
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
|
||||
|
||||
// C_m_o = A_m_k * B0_k_n * B1_n_o
|
||||
Tensor<ADataType> a_g_m_k(
|
||||
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
|
||||
Tensor<B0DataType> b0_g_k_n(
|
||||
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
|
||||
Tensor<B1DataType> b1_g_n_o(
|
||||
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
|
||||
Tensor<CDataType> c_gs_ms_os_host_result(
|
||||
std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
|
||||
std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
|
||||
Tensor<CDataType> c_gs_ms_os_device_result(
|
||||
std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
|
||||
std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
|
||||
// Host verification: Output of Gemm0 is input A of Gemm1
|
||||
Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
|
||||
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
|
||||
Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{BatchCount, M, O},
|
||||
std::vector<int>{M * O, O, 1});
|
||||
|
||||
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
|
||||
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
|
||||
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
|
||||
std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
|
||||
std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
|
||||
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
|
||||
|
||||
std::srand(1); // work around test flakiness
|
||||
@@ -157,38 +120,38 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
|
||||
// or not. May want to try exact same approach as the GPU kernel in the host reference
|
||||
// GEMM+Softmax+GEMM function to see if the accuracy discrepancy goes away. Until then,
|
||||
// shrink the input value range as it is less likely to produce errors of around ~1e-3.
|
||||
// a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
// b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
|
||||
// b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
// a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
// b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
|
||||
// b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
case 2:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
|
||||
break;
|
||||
case 3:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
break;
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSize());
|
||||
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
|
||||
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
|
||||
DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) *
|
||||
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) *
|
||||
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
|
||||
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
|
||||
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
|
||||
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
|
||||
b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
|
||||
b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b0_element_op = B0ElementOp{};
|
||||
@@ -196,20 +159,23 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
|
||||
auto b1_element_op = B1ElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
using DeviceOp =
|
||||
tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
|
||||
B0Layout,
|
||||
B1Layout,
|
||||
CPermuteNumDims_G_M_O,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
Acc0ElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp>;
|
||||
using DeviceOp = tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
ck::Tuple<>,
|
||||
ck::Tuple<>,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
Acc0ElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp,
|
||||
MaskingSpec>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
@@ -219,6 +185,26 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
|
||||
|
||||
Tensor<ADataType> a_g_m_k({BatchCount, M, K});
|
||||
Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
|
||||
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
|
||||
Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
|
||||
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
|
||||
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
|
||||
|
||||
// permute
|
||||
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
|
||||
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
|
||||
});
|
||||
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
|
||||
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
|
||||
});
|
||||
b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
|
||||
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
|
||||
});
|
||||
|
||||
auto ref_gemm0 = ReferenceGemm0Instance{};
|
||||
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
|
||||
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
|
||||
@@ -228,7 +214,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
|
||||
|
||||
// mask out upper triangle
|
||||
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
|
||||
if(idx[1] < idx[2])
|
||||
if(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle && idx[1] < idx[2])
|
||||
self(idx) = -ck::NumericLimits<float>::Infinity();
|
||||
});
|
||||
|
||||
@@ -265,23 +251,24 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_gs_ms_os_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
BatchCount,
|
||||
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
{}, // std::array<void*, 1> p_acc0_biases;
|
||||
{}, // std::array<void*, 1> p_acc1_biases;
|
||||
a_gs_ms_ks_lengths,
|
||||
a_gs_ms_ks_strides,
|
||||
b0_gs_ns_ks_lengths,
|
||||
b0_gs_ns_ks_strides,
|
||||
b1_gs_os_ns_lengths,
|
||||
b1_gs_os_ns_strides,
|
||||
c_gs_ms_os_lengths,
|
||||
c_gs_ms_os_strides,
|
||||
StrideA,
|
||||
StrideB0,
|
||||
StrideB1,
|
||||
BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideB1,
|
||||
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
|
||||
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
|
||||
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
|
||||
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
|
||||
a_element_op,
|
||||
b0_element_op,
|
||||
acc0_element_op,
|
||||
@@ -319,18 +306,18 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
|
||||
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
|
||||
|
||||
pass = pass & ck::utils::check_err(c_gs_ms_os_device_result.mData,
|
||||
c_gs_ms_os_host_result.mData);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a_g_m_k: ", a_g_m_k.mData, ",")
|
||||
LogRangeAsType<float>(std::cout << "a_gs_ms_ks: ", a_gs_ms_ks.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b0_g_k_n : ", b0_g_k_n.mData, ",")
|
||||
LogRangeAsType<float>(std::cout << "b0_gs_ns_ks : ", b0_gs_ns_ks.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b1_g_n_o : ", b1_g_n_o.mData, ",")
|
||||
LogRangeAsType<float>(std::cout << "b1_gs_os_ns : ", b1_gs_os_ns.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",")
|
||||
Reference in New Issue
Block a user