Input/output permutation for fused attention (#460)
* reopen masking attention instance now that CI is upgraded
* re-enable instances that previously failed on 9110
* enable ksize-kpadding pair validity test
* add non-masked attention+permute test; expose masking boolean to attention kernel handles
* disable bench
* fix test
* move files
* bulk rename batched_gemm_masking_scale_softmax_gemm_permute to batched_gemm_softmax_gemm_permute
* format
* amend rename
* disable bench in test
* add mask/no-mask test for non-permute attention kernels
* disable broken kernel instance
* example working
* add non-permuted problem statement evaluating whether overhead comes from permutation or the extra kernel arg
* interface for bias addition, without implementing it
* test and profiler running
* tidy
* mask type determined by enum class
* unify example code
* move masking specialization to its own header
* align formats
* extract helper functions
* experiment with merging dims for attention with permute; shows perf parity with attention without permute
* add tensor specialization to template args, since TensorSpecialization::Packed shows perf parity when permutation isn't needed
* remove redundant template args
* comment on 'packed' tensor specialization
* grouped attention with input/output permute example
* format
* clean up
* refactor acc0 tile visitor

Co-authored-by: shaojiewang <wsjmessi@163.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
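The headline interface change in this commit: the boolean `MaskOutUpperTriangle` template parameter (marked `// TODO: enum for mask type` in the first hunk below) becomes a `MaskingSpecialization` enum class from the new `masking_specialization.hpp` header. A minimal sketch of the idea — the enum values appear in this diff, but the free function here is a hypothetical illustration, not code from the commit:

// Sketch: MaskDisabled / MaskOutUpperTriangle are the enum values used in
// this diff; is_masked() is a hypothetical helper for illustration only.
enum class MaskingSpecialization
{
    MaskDisabled,
    MaskOutUpperTriangle
};

template <MaskingSpecialization MaskingSpec>
constexpr bool is_masked(int m, int n)
{
    if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
    {
        return n > m; // causal attention: position m may not attend to n > m
    }
    else
    {
        return false; // MaskDisabled
    }
}

An enum scales to further mask types without adding another bool parameter, which is presumably why the old code carried that TODO.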
@@ -24,7 +24,8 @@ template <typename ALayout,
           typename B0ElementwiseOperation,
           typename Acc0ElementwiseOperation,
           typename B1ElementwiseOperation,
-          typename CElementwiseOperation>
+          typename CElementwiseOperation,
+          bool MaskOutUpperTriangle> // TODO: enum for mask type
 struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument>

@@ -7,49 +7,60 @@
 #include <vector>

 #include "device_base.hpp"
+#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {

-template <typename ALayout,
-          typename B0Layout,
-          typename B1Layout,
-          typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
+template <index_t NumDimG,
+          index_t NumDimM,
+          index_t NumDimN,
+          index_t NumDimK,
+          index_t NumDimO,
           typename ADataType,
           typename B0DataType,
           typename B1DataType,
           typename CDataType,
+          typename Acc0BiasDataType,
+          typename Acc1BiasDataType,
           typename AElementwiseOperation,
           typename B0ElementwiseOperation,
           typename Acc0ElementwiseOperation,
           typename B1ElementwiseOperation,
-          typename CElementwiseOperation>
+          typename CElementwiseOperation,
+          MaskingSpecialization MaskingSpec>
 struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
 {
-    virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const void* p_a,
-                        const void* p_b0,
-                        const void* p_b1,
-                        void* p_c,
-                        ck::index_t M,
-                        ck::index_t N,
-                        ck::index_t K,
-                        ck::index_t O,
-                        ck::index_t Batch,
-                        std::vector<index_t> c_gs_ms_os_lengths,
-                        std::vector<index_t> c_gs_ms_os_strides,
-                        ck::index_t StrideA,
-                        ck::index_t StrideB0,
-                        ck::index_t StrideB1,
-                        ck::index_t BatchStrideA,
-                        ck::index_t BatchStrideB0,
-                        ck::index_t BatchStrideB1,
-                        AElementwiseOperation a_element_op,
-                        B0ElementwiseOperation b0_element_op,
-                        Acc0ElementwiseOperation acc0_element_op,
-                        B1ElementwiseOperation b1_element_op,
-                        CElementwiseOperation c_element_op) = 0;
+    static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
+    static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
+
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
+        const void* p_a,
+        const void* p_b0,
+        const void* p_b1,
+        void* p_c,
+        const std::array<void*, NumAcc0Bias> p_acc0_biases,
+        const std::array<void*, NumAcc1Bias> p_acc1_biases,
+        const std::vector<index_t>& a_gs_ms_ks_lengths,
+        const std::vector<index_t>& a_gs_ms_ks_strides,
+        const std::vector<index_t>& b_gs_ns_ks_lengths,
+        const std::vector<index_t>& b_gs_ns_ks_strides,
+        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
+        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
+        const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
+        const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
+        const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
+        const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
+        const std::array<std::vector<index_t>, NumAcc1Bias>
+            acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
+        const std::array<std::vector<index_t>, NumAcc1Bias>
+            acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
+        AElementwiseOperation a_element_op,
+        B0ElementwiseOperation b0_element_op,
+        Acc0ElementwiseOperation acc0_element_op,
+        B1ElementwiseOperation b1_element_op,
+        CElementwiseOperation c_element_op) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

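Why lengths/strides vectors instead of `M/N/K/O` plus per-tensor strides: an arbitrary input permutation can now be folded into the tensor coordinates. A hedged example of filling the A coordinates (shapes and the BSHD-style layout are invented; assumes `ck::index_t` and `<vector>` in scope):

// Hypothetical shapes: G0 = 4 batches, G1 = 8 heads, M = 256 queries,
// K = 64 head dim. Logical A view is [G0, G1, M, K]; physical storage is
// A[G0, M, G1, K] ("BSHD"-style), expressed purely through strides:
void fill_a_coords(std::vector<ck::index_t>& lengths, std::vector<ck::index_t>& strides)
{
    const ck::index_t G0 = 4, G1 = 8, M = 256, K = 64;
    lengths = {G0, G1, M, K};
    strides = {M * G1 * K, // G0: one whole [M, G1, K] slab
               K,          // G1: heads sit next to each other within a row
               G1 * K,     // M: a row steps over all heads
               1};         // K: contiguous
}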
@@ -7,46 +7,50 @@
 #include <vector>

 #include "device_base.hpp"
+#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {

-template <typename ALayout,
-          typename B0Layout,
-          typename B1Layout,
-          typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
+template <index_t NumDimG,
+          index_t NumDimM,
+          index_t NumDimN,
+          index_t NumDimK,
+          index_t NumDimO,
           typename ADataType,
           typename B0DataType,
           typename B1DataType,
           typename CDataType,
+          typename Acc0BiasDataType,
+          typename Acc1BiasDataType,
           typename AElementwiseOperation,
           typename B0ElementwiseOperation,
           typename Acc0ElementwiseOperation,
           typename B1ElementwiseOperation,
-          typename CElementwiseOperation>
+          typename CElementwiseOperation,
+          MaskingSpecialization MaskingSpec>
 struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
 {
     struct ProblemDesc
     {
-        // Overall problem shape
-        index_t M;
-        index_t N;
-        index_t K;
-        index_t O;
-        index_t Batch;
+        std::vector<index_t> a_gs_ms_ks_lengths;
+        std::vector<index_t> a_gs_ms_ks_strides;

-        // Stride for A/B0/B1; layout determined by template args
-        index_t StrideA;
-        index_t StrideB0;
-        index_t StrideB1;
-        index_t BatchStrideA;
-        index_t BatchStrideB0;
-        index_t BatchStrideB1;
+        std::vector<index_t> b0_gs_ns_ks_lengths;
+        std::vector<index_t> b0_gs_ns_ks_strides;
+
+        std::vector<index_t> b1_gs_os_ns_lengths;
+        std::vector<index_t> b1_gs_os_ns_strides;

-        // Lengths and strides for output C
         std::vector<index_t> c_gs_ms_os_lengths;
         std::vector<index_t> c_gs_ms_os_strides;
+
+        std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
+        std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;
+
+        std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths;
+        std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides;
     };

     virtual std::unique_ptr<BaseArgument>
@@ -54,6 +58,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
                         std::vector<const void*> p_b0_vec,
                         std::vector<const void*> p_b1_vec,
                         std::vector<void*> p_c_vec,
+                        std::vector<std::vector<const void*>> p_acc0_biases_vec,
+                        std::vector<std::vector<const void*>> p_acc1_biases_vec,
                         std::vector<ProblemDesc> problem_desc_vec,
                         AElementwiseOperation a_element_op,
                         B0ElementwiseOperation b0_element_op,

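A hedged sketch of filling one group's `ProblemDesc` from the hunk above (all shapes invented; the bias vectors stay empty since bias addition is still unimplemented in this commit):

// Sketch: packed Q/K/V inputs, output permuted to [G0, M, G1, O].
void fill_problem_desc(ProblemDesc& desc)
{
    const ck::index_t G0 = 4, G1 = 8, M = 256, N = 256, K = 64, O = 64;
    desc.a_gs_ms_ks_lengths  = {G0, G1, M, K};
    desc.a_gs_ms_ks_strides  = {G1 * M * K, M * K, K, 1}; // packed [G0, G1, M, K]
    desc.b0_gs_ns_ks_lengths = {G0, G1, N, K};
    desc.b0_gs_ns_ks_strides = {G1 * N * K, N * K, K, 1}; // packed
    desc.b1_gs_os_ns_lengths = {G0, G1, O, N};
    desc.b1_gs_os_ns_strides = {G1 * O * N, O * N, N, 1}; // packed
    desc.c_gs_ms_os_lengths  = {G0, G1, M, O};
    desc.c_gs_ms_os_strides  = {M * G1 * O, O, G1 * O, 1}; // C stored as [G0, M, G1, O]
    // bias addition is not implemented yet, so these stay empty
    desc.acc0_biases_gs_ms_ns_lengths = {};
    desc.acc0_biases_gs_ms_ns_strides = {};
    desc.acc1_biases_gs_ms_os_lengths = {};
    desc.acc1_biases_gs_ms_os_strides = {};
}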
@@ -14,6 +14,7 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
+#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"

@@ -54,9 +55,8 @@ __global__ void
     index_t right = group_count;
     index_t group_id = index_t((left + right) / 2);

-    while((!(block_id >= arg_ptr[group_id].block_start_ &&
-             block_id < arg_ptr[group_id].block_end_)) &&
-          left <= right)
+    while(
+        (!(block_id >= arg_ptr[group_id].block_start_ && block_id < arg_ptr[group_id].block_end_)))
     {
        if(block_id < arg_ptr[group_id].block_start_)
        {
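For readers of the hunk above: each workgroup locates its group by bisecting the sorted, disjoint `[block_start_, block_end_)` ranges, and the `left <= right` guard could be dropped because every `block_id` in the launched grid belongs to some group. A standalone host-side sketch (the range-update branches are inferred; only the first branch is visible in this hunk, and `GroupRange` is a hypothetical stand-in for the kernel argument type):

struct GroupRange
{
    int block_start_;
    int block_end_;
};

int find_group(const GroupRange* arg_ptr, int group_count, int block_id)
{
    int left     = 0;
    int right    = group_count;
    int group_id = (left + right) / 2;

    // Invariant: block_id falls inside exactly one group's range.
    while(!(block_id >= arg_ptr[group_id].block_start_ &&
            block_id < arg_ptr[group_id].block_end_))
    {
        if(block_id < arg_ptr[group_id].block_start_)
        {
            right = group_id; // target is in the lower half
        }
        else
        {
            left = group_id; // target is in the upper half
        }
        group_id = (left + right) / 2;
    }
    return group_id;
}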
@@ -114,14 +114,17 @@ __global__ void
 // Computes C = A * B0 * B1
 //              ^^^^^^ (Acc0)
 //              ^^^^^^^^^^^ (Acc1)
-template <typename ALayout,
-          typename BLayout, // B0Layout
-          typename B1Layout,
-          typename CPermuteNumDims_G_M_Gemm1N, // Sequence<NumDimG, NumDimM, NumDimGemm1N>
+template <index_t NumDimG,
+          index_t NumDimM,
+          index_t NumDimN,
+          index_t NumDimK,
+          index_t NumDimO, // NumDimGemm1N
           typename ADataType,
           typename BDataType,
           typename B1DataType,
           typename CDataType,
+          typename Acc0BiasDataType,
+          typename Acc1BiasDataType,
           typename GemmAccDataType,
           typename CShuffleDataType,
           typename AElementwiseOperation,
@@ -130,6 +133,10 @@ template <typename ALayout,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           GemmSpecialization GemmSpec,
+          TensorSpecialization ASpec,
+          TensorSpecialization BSpec,
+          TensorSpecialization B1Spec,
+          TensorSpecialization CSpec,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
           index_t MPerBlock,
@@ -170,297 +177,152 @@ template <typename ALayout,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
-          bool MaskOutUpperTriangle,
+          MaskingSpecialization MaskingSpec,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
-    : public DeviceGroupedGemmSoftmaxGemmPermute<ALayout,
-                                                 BLayout,
-                                                 B1Layout,
-                                                 CPermuteNumDims_G_M_Gemm1N,
+    : public DeviceGroupedGemmSoftmaxGemmPermute<NumDimG,
+                                                 NumDimM,
+                                                 NumDimN,
+                                                 NumDimK,
+                                                 NumDimO,
                                                  ADataType,
                                                  BDataType,
                                                  B1DataType,
                                                  CDataType,
+                                                 Acc0BiasDataType,
+                                                 Acc1BiasDataType,
                                                  AElementwiseOperation,
                                                  BElementwiseOperation,
                                                  AccElementwiseOperation,
                                                  B1ElementwiseOperation,
-                                                 CElementwiseOperation>
+                                                 CElementwiseOperation,
+                                                 MaskingSpec>
 {
-    using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle;
-    using ProblemDesc =
-        typename DeviceGroupedGemmSoftmaxGemmPermute<ALayout,
-                                                     BLayout,
-                                                     B1Layout,
-                                                     CPermuteNumDims_G_M_Gemm1N,
-                                                     ADataType,
-                                                     BDataType,
-                                                     B1DataType,
-                                                     CDataType,
-                                                     AElementwiseOperation,
-                                                     BElementwiseOperation,
-                                                     AccElementwiseOperation,
-                                                     B1ElementwiseOperation,
-                                                     CElementwiseOperation>::ProblemDesc;
+    static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
+                  "Number of dimension must be greater than 0");
+
+    static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
+    static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
+
+    // TODO ANT: implement bias combination
+    static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
+
+#if 0
+    // TODO ANT: use alias
+    static constexpr index_t NumDimGemm0M = NumDimM;
+    static constexpr index_t NumDimGemm0N = NumDimN;
+    static constexpr index_t NumDimGemm0K = NumDimK;
+    static constexpr index_t NumDimGemm1M = NumDimM;
+    static constexpr index_t NumDimGemm1N = NumDimO;
+    static constexpr index_t NumDimGemm1K = NumDimN;
+#endif
+
+    using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle;
+    using ProblemDesc = typename DeviceGroupedGemmSoftmaxGemmPermute<NumDimG,
+                                                                     NumDimM,
+                                                                     NumDimN,
+                                                                     NumDimK,
+                                                                     NumDimO,
+                                                                     ADataType,
+                                                                     BDataType,
+                                                                     B1DataType,
+                                                                     CDataType,
+                                                                     Acc0BiasDataType,
+                                                                     Acc1BiasDataType,
+                                                                     AElementwiseOperation,
+                                                                     BElementwiseOperation,
+                                                                     AccElementwiseOperation,
+                                                                     B1ElementwiseOperation,
+                                                                     CElementwiseOperation,
+                                                                     MaskingSpec>::ProblemDesc;

     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};

-    static constexpr auto matrix_padder =
-        GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
-            MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
+    using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
+        Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
+        Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
+        GemmSpec,
+        ASpec,
+        BSpec,
+        B1Spec,
+        CSpec>;

-    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
+    static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
+                                              const std::vector<index_t>& a_gs_ms_ks_strides_vec)
     {
-        const auto a_grid_desc_mraw_kraw = [&]() {
-            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
-            {
-                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
-                                                    make_tuple(StrideA, I1));
-            }
-            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
-            {
-                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
-                                                    make_tuple(I1, StrideA));
-            }
-        }();
-
-        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
-
-        const auto M = a_grid_desc_m_k.GetLength(I0);
-        const auto K = a_grid_desc_m_k.GetLength(I1);
-
-        const auto AK0 = K / AK1;
-
-        return transform_tensor_descriptor(a_grid_desc_m_k,
-                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
-                                                      make_pass_through_transform(M)),
-                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        return Transform::MakeAGridDescriptor_AK0_M_AK1(
+            Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
+            Number<AK1>{});
     }

-    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
+    static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
+                                              const std::vector<index_t>& b_gs_ns_ks_strides_vec)
     {
-        const auto b_grid_desc_nraw_kraw = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
-                                                    make_tuple(I1, StrideB));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
-                                                    make_tuple(StrideB, I1));
-            }
-        }();
-
-        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
-
-        const auto N = b_grid_desc_n_k.GetLength(I0);
-        const auto K = b_grid_desc_n_k.GetLength(I1);
-
-        const auto BK0 = K / BK1;
-
-        return transform_tensor_descriptor(b_grid_desc_n_k,
-                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                                                      make_pass_through_transform(N)),
-                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        return Transform::MakeB0GridDescriptor_BK0_N_BK1(
+            Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
+            Number<BK1>{});
     }

-    // Args: Gemm1KRaw, Gemm1NRaw, StrideB1
-    static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
+    static auto
+    MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
+                                   const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
     {
-        const auto b1_grid_desc_nraw_kraw = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
-                                                    make_tuple(I1, StrideB));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
-                                                    make_tuple(StrideB, I1));
-            }
-        }();
-
-        const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
-
-        const auto N = b1_grid_desc_n_k.GetLength(I0);
-        const auto K = b1_grid_desc_n_k.GetLength(I1);
-
-        const auto B1K0 = K / B1K1;
-
-        return transform_tensor_descriptor(
-            b1_grid_desc_n_k,
-            make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
-                       make_pass_through_transform(N)),
-            make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        return Transform::MakeB1GridDescriptor_BK0_N_BK1(
+            Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
+                                                b1_gs_gemm1ns_gemm1ks_strides_vec),
+            Number<B1K1>{});
     }

-    // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
-    static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
-                                        const std::vector<index_t>& c_gs_ms_ns_strides_vec)
-    {
-        constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
-        constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
-        constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
-
-        assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
-               c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
-
-        const auto to_tuple = [&](auto& vec, auto start, auto end) {
-            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
-        };
-
-        const auto c_ms_ns_lengths = to_tuple(
-            c_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
-        const auto c_ms_ns_strides = to_tuple(
-            c_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
-
-        // dimension Ids for M0, M1, ...
-        constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
-
-        // dimension Ids for N0, N1, ...
-        constexpr auto nDimIds =
-            typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
-
-        // lengths for M0, M1, ...
-        const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds);
-
-        // lengths for K0, K1, ...
-        const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds);
-
-        // naive tensor C[M0, M1, M2, ..., N0, N1, N2...]
-        const auto c_grid_desc_ms_ns =
-            make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
-
-        // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
-        const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor(
-            c_grid_desc_ms_ns,
-            make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
-            make_tuple(mDimIds, nDimIds),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
-    }
-
-    // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
-    static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
-                                          const std::vector<index_t>& c_gs_ms_ns_strides_vec)
-    {
-        constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
-        constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
-        constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
-
-        assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
-               c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
-
-        const auto to_tuple = [&](auto& vec, auto start, auto end) {
-            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
-        };
-
-        const auto c_gs_ms_ns_lengths =
-            to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
-        const auto c_gs_ms_ns_strides =
-            to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
-
-        // dimension Ids for G0, G1, ...
-        constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
-
-        // dimension Ids for M0, M1, ...
-        constexpr auto mDimIds =
-            typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
-
-        // dimension Ids for N0, N1, ...
-        constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
-                                                                  NumDimG + NumDimM + NumDimN,
-                                                                  1>::type{};
-
-        // lengths for G0, G1, ...
-        const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds);
-
-        // lengths for M0, M1, ...
-        const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds);
-
-        // lengths for K0, K1, ...
-        const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds);
-
-        // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
-        const auto c_grid_desc_gs_ms_ns =
-            make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides);
-
-        // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
-        // N2 * ...]
-        const auto c_grid_desc_g_mraw_nraw =
-            transform_tensor_descriptor(c_grid_desc_gs_ms_ns,
-                                        make_tuple(make_merge_transform(gLengths),
-                                                   make_merge_transform(mLengths),
-                                                   make_merge_transform(nLengths)),
-                                        make_tuple(gDimIds, mDimIds, nDimIds),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
-
-        // this desc is only for calculating batch offset so no padding needed
-        return c_grid_desc_g_mraw_nraw;
-    }
-
-    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
-    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
-    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
-    using CGridDesc_M_N        = decltype(MakeCGridDescriptor_M_N({}, {}));
-    using CGridDesc_G_M_N      = decltype(MakeCGridDescriptor_G_M_N({}, {}));
-
-    // to track the points which need to be set to -inf on C0
-    // Note: no need to reset M padding value, because they will not be stored out.
-    struct C0MatrixMask
-    {
-        C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
-
-        __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
-
-        __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
-        {
-            return n >= NRaw_;
-        }
-
-        __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
-        {
-            return IsUpperTriangle(m, n) || IsNOutOfBound(n);
-        }
-
-        private:
-        // index_t MRaw_;
-        index_t NRaw_;
-    };
+    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
+    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
+    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
+    using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
+    using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
+    using BGridDesc_G_N_K      = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
+    using B1GridDesc_G_N_K     = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
+    using CGridDesc_G_M_N      = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+
+    constexpr static auto make_MaskOutPredicate()
+    {
+        if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
+        {
+            return MaskDisabledPredicate{};
+        }
+        else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
+        {
+            return MaskOutUpperTrianglePredicate{};
+        }
+    }
+    using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;

     struct ComputeBasePtrOfStridedBatch
     {
-        ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
-                                     index_t BatchStrideB,
-                                     index_t BatchStrideB1,
-                                     CGridDesc_G_M_N c_grid_desc_g_m_n)
-            : BatchStrideA_(BatchStrideA),
-              BatchStrideB_(BatchStrideB),
-              BatchStrideB1_(BatchStrideB1),
+        ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
+                                     const BGridDesc_G_N_K& b_grid_desc_g_n_k,
+                                     const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
+                                     const CGridDesc_G_M_N& c_grid_desc_g_m_n)
+            : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
+              b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
+              b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
               c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
         {
         }

         __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
         {
-            return g_idx * static_cast<long_index_t>(BatchStrideA_);
+            return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }

         __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
         {
-            return g_idx * static_cast<long_index_t>(BatchStrideB_);
+            return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }

         __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
         {
-            return g_idx * static_cast<long_index_t>(BatchStrideB1_);
+            return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }

         __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
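The hand-written `C0MatrixMask` above is replaced by `C0MatrixMask_impl` parameterized on a predicate selected at compile time. A simplified sketch of that pattern — the real predicates and `C0MatrixMask_impl` live in the new `masking_specialization.hpp`, and the `operator()` call signature here is an assumption:

// Simplified stand-ins; not the actual definitions from the commit.
struct MaskDisabledPredicate
{
    __host__ __device__ constexpr bool operator()(int /*m*/, int /*n*/) const { return false; }
};

struct MaskOutUpperTrianglePredicate
{
    __host__ __device__ constexpr bool operator()(int m, int n) const { return n > m; }
};

// What the mask object composes: a compile-time predicate plus the runtime
// N bound (padded N positions must still read as -inf).
template <typename MaskOutPredicate>
struct C0MatrixMaskSketch
{
    C0MatrixMaskSketch(int NRaw) : NRaw_(NRaw) {}

    __host__ __device__ bool IsMaskedElement(int m, int n) const
    {
        return predicate_(m, n) || n >= NRaw_;
    }

    private:
    int NRaw_;
    MaskOutPredicate predicate_;
};

With `MaskDisabled`, the predicate folds to a constant `false`, so the only surviving check is the N bound; no per-element triangle test is paid.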
@@ -469,9 +331,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
         }

         private:
-        index_t BatchStrideA_;
-        index_t BatchStrideB_;
-        index_t BatchStrideB1_;
+        AGridDesc_G_M_K a_grid_desc_g_m_k_;
+        BGridDesc_G_N_K b_grid_desc_g_n_k_;
+        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
         CGridDesc_G_M_N c_grid_desc_g_m_n_;
     };

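The members above show the trade: instead of one scalar batch stride per tensor, each tensor keeps a G-merged descriptor, so a batch index can map through several, possibly permuted, G dimensions. A host-side sketch of the arithmetic for a merged G = G0 * G1 (illustrative; this mirrors what `CalculateOffset(make_multi_index(g, 0, 0))` evaluates to for a two-level G):

// Hypothetical two-level batch dimension with strides s_g0, s_g1.
long batch_offset(long g, long G1, long s_g0, long s_g1)
{
    // The old scheme collapses to g * BatchStride when G has one level.
    return (g / G1) * s_g0 + (g % G1) * s_g1;
}

The in-code TODO ("only keep batch stride in tensor desc to reduce scalar cache pressure") notes the cost: four full descriptors now ride in the kernel arguments instead of three scalars.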
@@ -535,8 +397,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched,
-        matrix_padder.PadN,
-        MaskOutUpperTriangle>;
+        Transform::matrix_padder.PadN,
+        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;

     using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;

@@ -570,16 +432,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle

     struct GroupDeviceArg
     {
-        // problem definiton
-        index_t M;
-        index_t N;
-        index_t K;
-        index_t O;
+        // lengths for the last dimensions of overall problem for sanity check of vector load/store
+        std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;

-        // Strides for the last dimensions of C for sanity check of vector load/store
-        index_t c_extent_lowest_;
-        index_t c_stride_lowest_;
+        // strides for the last dimensions of each tensor for sanity check of vector load/store
+        std::vector<index_t> a_mz_kz_strides_;
+        std::vector<index_t> b_nz_kz_strides_;
+        std::vector<index_t> b1_nz_kz_strides_;
+        std::vector<index_t> c_mz_gemm1nz_strides_;

         // for gridwise gemm check
         CGridDesc_M_N c_grid_desc_m_n_;
     };

@@ -591,6 +453,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                  std::vector<const void*> p_b_vec,
                  std::vector<const void*> p_b1_vec,
                  std::vector<void*> p_c_vec,
+                 std::vector<std::vector<const void*>> p_acc0_biases_vec,
+                 std::vector<std::vector<const void*>> p_acc1_biases_vec,
                  std::vector<ProblemDesc> problem_desc_vec,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
@@ -603,6 +467,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
           b1_element_op_{b1_element_op},
           c_element_op_{c_element_op}
     {
+        // TODO ANT: implement bias addition
         group_count_ = problem_desc_vec.size();

         if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
@@ -611,6 +476,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
             throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
         }

+        if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size()))
+        {
+            throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
+        }
+
         grid_size_ = 0;

         for(std::size_t i = 0; i < group_count_; i++)
@@ -620,14 +490,25 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
             const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
             const auto p_c_grid  = static_cast<CDataType*>(p_c_vec[i]);

-            const auto a_grid_desc_ak0_m_ak1 = DeviceOp::MakeAGridDescriptor_AK0_M_AK1(
-                problem_desc_vec[i].M, problem_desc_vec[i].K, problem_desc_vec[i].StrideA);
-            const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
-                problem_desc_vec[i].K, problem_desc_vec[i].N, problem_desc_vec[i].StrideB0);
-            const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
-                problem_desc_vec[i].N, problem_desc_vec[i].O, problem_desc_vec[i].StrideB1);
-            const auto c_grid_desc_m_n = DeviceOp::MakeCGridDescriptor_M_N(
-                problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides);
+            const auto& problem_desc = problem_desc_vec[i];
+
+            const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
+                problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
+            const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
+                problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
+            const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
+                problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
+            const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
+                problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
+
+            const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(
+                problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
+            const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
+                problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
+            const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
+                problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
+            const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
+                problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);

             const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
                 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -635,25 +516,32 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle

             const index_t BlockStart     = grid_size_;
             const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
-            const index_t grid_size_grp = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) *
-                                          problem_desc_vec[i].Batch;
+            const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
+            const index_t grid_size_grp =
+                block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count;
             const index_t BlockEnd = grid_size_ + grid_size_grp;

-            // batch stride
-            const auto c_grid_desc_g_m_n = DeviceOp::MakeCGridDescriptor_G_M_N(
-                problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides);
-            const auto compute_base_ptr_of_batch =
-                ComputeBasePtrOfStridedBatch(problem_desc_vec[i].BatchStrideA,
-                                             problem_desc_vec[i].BatchStrideB0,
-                                             problem_desc_vec[i].BatchStrideB1,
-                                             c_grid_desc_g_m_n);
+            // TODO ANT: only keep batch stride in tensor desc to reduce scalar cache pressure
+            const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
+                a_grid_desc_g_m_k, b_grid_desc_g_n_k, b1_grid_desc_g_n_k, c_grid_desc_g_m_n);

             // C0 mask
-            const auto c0_matrix_mask = C0MatrixMask(problem_desc_vec[i].N);
+            const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));

             grid_size_ += grid_size_grp;

+            // for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
+            // so on
+            if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
+                 problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
+                 problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
+                 problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
+            {
+                throw std::runtime_error(
+                    "wrong! number of biases in function argument does not "
+                    "match that in template argument");
+            }
+
             group_kernel_args_.push_back({p_a_grid,
                                           p_b_grid,
                                           p_b1_grid,
@@ -669,13 +557,20 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                           BlockStart,
                                           BlockEnd});

-            group_device_args_.push_back({problem_desc_vec[i].M,
-                                          problem_desc_vec[i].N,
-                                          problem_desc_vec[i].K,
-                                          problem_desc_vec[i].O,
-                                          problem_desc_vec[i].c_gs_ms_os_lengths.back(),
-                                          problem_desc_vec[i].c_gs_ms_os_strides.back(),
-                                          c_grid_desc_m_n});
+            group_device_args_.push_back(
+                {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
+                  problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
+                  problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
+                  problem_desc.b1_gs_os_ns_lengths[NumDimG + NumDimO - 1]},
+                 {problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
+                  problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
+                 {problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN - 1],
+                  problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
+                 {problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO - 1],
+                  problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]},
+                 {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
+                  problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
+                 c_grid_desc_m_n});
         }
     }

@@ -788,6 +683,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
             return false;
         }

+        // TODO ANT: Check if tensor specialization & strides mismatch
+
         bool all_has_main_k_block_loop  = true;
         bool some_has_main_k_block_loop = false;

@@ -815,19 +712,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle

             // Note: we need raw lengths since threadwise copy can not handle vector load when
             // part of vector is out of bounds
-            const auto MRaw      = device_arg.M;
-            const auto NRaw      = device_arg.N;
-            const auto KRaw      = device_arg.K;
-            const auto Gemm1NRaw = device_arg.O;
+            const auto MzRaw      = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
+            const auto NzRaw      = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
+            const auto KzRaw      = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
+            const auto Gemm1NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[3];

             // Check scalar per vector requirement
-            const auto a_extent_lowest =
-                is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
-            const auto b_extent_lowest =
-                is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
-            const auto b1_extent_lowest =
-                is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
-            const auto c_extent_lowest = device_arg.c_extent_lowest_;
+            const auto a_extent_lowest  = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
+            const auto b_extent_lowest  = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
+            const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
+            const auto c_extent_lowest  = Gemm1NzRaw;

             if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
                  b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
@@ -837,8 +731,22 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 return false;
             }

-            // Check vector store requirement; assumes last dimension in N to be contiguous
-            if(device_arg.c_stride_lowest_ != 1)
+            // Check vector load/store requirement
+            const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
+                                             ? device_arg.a_mz_kz_strides_[1]
+                                             : device_arg.a_mz_kz_strides_[0];
+            const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
+                                             ? device_arg.b_nz_kz_strides_[1]
+                                             : device_arg.b_nz_kz_strides_[0];
+            const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
+                                              ? device_arg.b1_nz_kz_strides_[1]
+                                              : device_arg.b1_nz_kz_strides_[0];
+            const auto c_stride_lowest =
+                device_arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be
+                                                     // contiguous
+
+            if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
+                 c_stride_lowest == 1))
             {
                 return false;
             }
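The two argument checks above, condensed into a sketch: the raw extent of the vectorized dimension must divide evenly by the configured vector width, since the threadwise copy cannot handle a vector that straddles the tensor boundary. Note the diff joins the per-tensor stride tests with `||`, so strictly only one tensor needs a unit-stride lowest dimension to pass. Names below are illustrative, not from the commit:

bool vector_width_ok(long extent_lowest, long scalar_per_vector)
{
    // A partial trailing vector would read/write out of bounds.
    return extent_lowest % scalar_per_vector == 0;
}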
@@ -873,6 +781,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                         std::vector<const void*> p_b_vec,
                         std::vector<const void*> p_b1_vec,
                         std::vector<void*> p_c_vec,
+                        std::vector<std::vector<const void*>> p_acc0_biases_vec,
+                        std::vector<std::vector<const void*>> p_acc1_biases_vec,
                         std::vector<ProblemDesc> problem_desc_vec,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
@@ -884,6 +794,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                           p_b_vec,
                                           p_b1_vec,
                                           p_c_vec,
+                                          p_acc0_biases_vec,
+                                          p_acc1_biases_vec,
                                           problem_desc_vec,
                                           a_element_op,
                                           b_element_op,
@@ -895,21 +807,26 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
     static auto MakeInvoker() { return Invoker{}; }

     // polymorphic
-    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a_vec,
-                                                      std::vector<const void*> p_b_vec,
-                                                      std::vector<const void*> p_b1_vec,
-                                                      std::vector<void*> p_c_vec,
-                                                      std::vector<ProblemDesc> problem_desc_vec,
-                                                      AElementwiseOperation a_element_op,
-                                                      BElementwiseOperation b_element_op,
-                                                      AccElementwiseOperation acc_element_op,
-                                                      B1ElementwiseOperation b1_element_op,
-                                                      CElementwiseOperation c_element_op) override
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(std::vector<const void*> p_a_vec,
+                        std::vector<const void*> p_b_vec,
+                        std::vector<const void*> p_b1_vec,
+                        std::vector<void*> p_c_vec,
+                        std::vector<std::vector<const void*>> p_acc0_biases_vec,
+                        std::vector<std::vector<const void*>> p_acc1_biases_vec,
+                        std::vector<ProblemDesc> problem_desc_vec,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        AccElementwiseOperation acc_element_op,
+                        B1ElementwiseOperation b1_element_op,
+                        CElementwiseOperation c_element_op) override
     {
         return std::make_unique<Argument>(p_a_vec,
                                           p_b_vec,
                                           p_b1_vec,
                                           p_c_vec,
+                                          p_acc0_biases_vec,
+                                          p_acc1_biases_vec,
                                           problem_desc_vec,
                                           a_element_op,
                                           b_element_op,
@@ -942,7 +859,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
             << Gemm1NPerBlock << ", "
             << Gemm1KPerBlock << ", "
             << B1K1 << ", "
-            << getGemmSpecializationString(GemmSpec) << ">";
+            << getGemmSpecializationString(GemmSpec) << ", "
+            << "ASpec" << getTensorSpecializationString(ASpec) << ", "
+            << "B0Spec" << getTensorSpecializationString(BSpec) << ", "
+            << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
+            << "CSpec" << getTensorSpecializationString(CSpec) << ", "
+            << getMaskingSpecializationString(MaskingSpec) << ">";
         // clang-format on

         return str.str();

@@ -130,8 +130,11 @@ namespace device {
 // D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
 // E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]

-// FIXME: TensorSpecialization::Packed specialization does not cover all packed tensor cases, it
-// merely degenerates into TensorSpecialization::Default with NumDimG/M/N/K = 1
+// NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner
+// dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and
+// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted
+// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into
+// TensorSpecialization::Default with NumDimG/M/N/K = 1
 //
 // Detail- Packed tensor satisfies
 //   stride_0 = 1
@@ -147,7 +150,7 @@ namespace device {
 // essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1.
 //
 // Might need to expose dimension order to the interface to fully support
-// TensorSpecialization::Packed.
+// TensorSpecialization::Packed in a traditional sense of "packed" tensor
 template <index_t NumDimG,
           index_t NumDimM,
           index_t NumDimN,

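The packed-tensor condition spelled out in the comment blocks above can be checked mechanically. A sketch, assuming dimensions within a group are ordered outermost to innermost (the function name and signature are illustrative, not from the commit):

#include <vector>

// Packed within a group: innermost stride == 1 and, walking outward,
// stride_i == stride_{i+1} * length_{i+1}.
bool is_packed(const std::vector<long>& lengths, const std::vector<long>& strides)
{
    long expected = 1;
    for(int i = static_cast<int>(lengths.size()) - 1; i >= 0; --i)
    {
        if(strides[i] != expected)
        {
            return false;
        }
        expected *= lengths[i];
    }
    return true;
}

A permuted-but-contiguous tensor fails this test, which is exactly the limitation the NOTE describes: Packed here degenerates into Default with one dimension per group.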
@@ -14,6 +14,7 @@
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
+#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"

@@ -116,14 +117,17 @@ __global__ void
 // Computes C = A * B0 * B1
 //              ^^^^^^ (Acc0)
 //              ^^^^^^^^^^^ (Acc1)
-template <typename ALayout,
-          typename BLayout, // B0Layout
-          typename B1Layout,
-          typename CPermuteNumDims_G_M_Gemm1N, // Sequence<NumDimG, NumDimM, NumDimGemm1N>
+template <index_t NumDimG,
+          index_t NumDimM,
+          index_t NumDimN,
+          index_t NumDimK,
+          index_t NumDimO, // NumDimGemm1N
           typename ADataType,
           typename BDataType,
           typename B1DataType,
           typename CDataType,
+          typename Acc0BiasDataType,
+          typename Acc1BiasDataType,
           typename GemmAccDataType,
           typename CShuffleDataType,
           typename AElementwiseOperation,
@@ -132,6 +136,10 @@ template <typename ALayout,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           GemmSpecialization GemmSpec,
+          TensorSpecialization ASpec,
+          TensorSpecialization BSpec,
+          TensorSpecialization B1Spec,
+          TensorSpecialization CSpec,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
           index_t MPerBlock,
@@ -172,283 +180,135 @@ template <typename ALayout,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
bool MaskOutUpperTriangle,
|
||||
MaskingSpecialization MaskingSpec,
|
||||
LoopScheduler LoopSched = LoopScheduler::Default>
|
||||
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
: public DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
|
||||
BLayout,
|
||||
B1Layout,
|
||||
CPermuteNumDims_G_M_Gemm1N,
|
||||
: public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
NumDimO,
|
||||
ADataType,
|
||||
BDataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
Acc0BiasDataType,
|
||||
Acc1BiasDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
CElementwiseOperation,
|
||||
MaskingSpec>
|
||||
{
|
||||
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
|
||||
"Number of dimension must be greater than 0");
|
||||
|
||||
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
|
||||
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
|
||||
|
||||
// TODO ANT: implement bias combination
|
||||
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
|
||||
|
||||
#if 0
|
||||
// TODO ANT: use alias
|
||||
static constexpr index_t NumDimGemm0M = NumDimM;
|
||||
static constexpr index_t NumDimGemm0N = NumDimN;
|
||||
static constexpr index_t NumDimGemm0K = NumDimK;
|
||||
static constexpr index_t NumDimGemm1M = NumDimM;
|
||||
static constexpr index_t NumDimGemm1N = NumDimO;
|
||||
static constexpr index_t NumDimGemm1K = NumDimN;
|
||||
#endif
|
||||
|
||||
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
|
||||
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
|
||||
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
|
||||
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
|
||||
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
|
||||
GemmSpec,
|
||||
ASpec,
|
||||
BSpec,
|
||||
B1Spec,
|
||||
CSpec>;
|
||||
|
||||
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
|
||||
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(StrideA, I1));
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(I1, StrideA));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
return transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
return Transform::MakeAGridDescriptor_AK0_M_AK1(
|
||||
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
|
||||
Number<AK1>{});
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
|
||||
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
return transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
|
||||
Number<BK1>{});
|
||||
}
|
||||
|
||||
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
|
||||
static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
static auto
|
||||
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
|
||||
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
|
||||
{
|
||||
const auto b1_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
|
||||
|
||||
const auto N = b1_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b1_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
const auto B1K0 = K / B1K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b1_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
|
||||
b1_gs_gemm1ns_gemm1ks_strides_vec),
|
||||
Number<B1K1>{});
|
||||
}
|
||||
|
||||
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
|
||||
static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
|
||||
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
|
||||
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
|
||||
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
|
||||
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
|
||||
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
|
||||
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
|
||||
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
|
||||
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
|
||||
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
|
||||
|
||||
constexpr static auto make_MaskOutPredicate()
|
||||
{
|
||||
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
|
||||
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
|
||||
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
|
||||
|
||||
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
|
||||
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
|
||||
|
||||
const auto to_tuple = [&](auto& vec, auto start, auto end) {
|
||||
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
|
||||
};
|
||||
|
||||
const auto c_ms_ns_lengths = to_tuple(
|
||||
c_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
|
||||
const auto c_ms_ns_strides = to_tuple(
|
||||
c_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
|
||||
|
||||
// dimension Ids for M0, M1, ...
|
||||
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
|
||||
|
||||
// dimension Ids for N0, N1, ...
|
||||
constexpr auto nDimIds =
|
||||
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
|
||||
|
||||
// lengths for M0, M1, ...
|
||||
const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds);
|
||||
|
||||
// lengths for K0, K1, ...
|
||||
const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds);
|
||||
|
||||
// naive tensor C[M0, M1, M2, ..., N0, N1, N2...]
|
||||
const auto c_grid_desc_ms_ns =
|
||||
make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
|
||||
|
||||
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
|
||||
const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor(
|
||||
c_grid_desc_ms_ns,
|
||||
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
|
||||
make_tuple(mDimIds, nDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
|
||||
}

    // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
    static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
                                          const std::vector<index_t>& c_gs_ms_ns_strides_vec)
    {
        constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
        constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
        constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N

        assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
               c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);

        const auto to_tuple = [&](auto& vec, auto start, auto end) {
            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
        };

        const auto c_gs_ms_ns_lengths =
            to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto c_gs_ms_ns_strides =
            to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // dimension Ids for G0, G1, ...
        constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds =
            typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
                                                                  NumDimG + NumDimM + NumDimN,
                                                                  1>::type{};

        // lengths for G0, G1, ...
        const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds);

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds);

        // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
        const auto c_grid_desc_gs_ms_ns =
            make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides);

        // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ..., NRaw = N0 * N1 * N2 * ...]
        const auto c_grid_desc_g_mraw_nraw =
            transform_tensor_descriptor(c_grid_desc_gs_ms_ns,
                                        make_tuple(make_merge_transform(gLengths),
                                                   make_merge_transform(mLengths),
                                                   make_merge_transform(nLengths)),
                                        make_tuple(gDimIds, mDimIds, nDimIds),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // this descriptor is only used for calculating the batch offset, so no padding is needed
        return c_grid_desc_g_mraw_nraw;
    }
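For intuition, a small sketch of what "only used for calculating the batch offset" means, with hypothetical strides rather than CK code: the merged batch index g is decomposed back into its G dimensions and dotted with their strides.

// G = {G0, G1} = {batches, heads}; the merged batch index g_idx is split
// back into (batch, head) and combined with the per-dimension strides.
long long batch_base_offset(int g_idx, int G1, long long stride_g0, long long stride_g1)
{
    const int i0 = g_idx / G1; // batch index
    const int i1 = g_idx % G1; // head index
    return i0 * stride_g0 + i1 * stride_g1;
}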

    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
    using CGridDesc_M_N        = decltype(MakeCGridDescriptor_M_N({}, {}));
    using CGridDesc_G_M_N      = decltype(MakeCGridDescriptor_G_M_N({}, {}));

    // to track the points which need to be set to -inf on C0
    // Note: no need to reset M padding value, because they will not be stored out.
    struct C0MatrixMask
    {
        C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}

        __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }

        __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
        {
            return n >= NRaw_;
        }

        __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
        {
            return IsUpperTriangle(m, n) || IsNOutOfBound(n);
        }

        private:
        // index_t MRaw_;
        index_t NRaw_;
    };
    static constexpr auto make_MaskOutPredicate()
    {
        if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
        {
            return MaskDisabledPredicate{};
        }
        else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
        {
            return MaskOutUpperTrianglePredicate{};
        }
    }

    using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
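A standalone sketch of the compile-time dispatch pattern used here; the names mirror the CK predicates, but this is an illustration in plain C++17, not the library headers:

#include <type_traits>

struct MaskDisabledPredicate
{
    constexpr bool operator()(int, int) const { return false; }
};

struct MaskOutUpperTrianglePredicate
{
    constexpr bool operator()(int m, int n) const { return n > m; }
};

// if constexpr selects the *type* of the returned predicate at compile time,
// so the disabled case carries no runtime masking cost.
template <bool MaskUpper>
constexpr auto make_predicate()
{
    if constexpr(MaskUpper)
        return MaskOutUpperTrianglePredicate{};
    else
        return MaskDisabledPredicate{};
}

static_assert(std::is_same_v<decltype(make_predicate<true>()), MaskOutUpperTrianglePredicate>);
static_assert(make_predicate<true>()(0, 1));   // n > m: masked
static_assert(!make_predicate<false>()(0, 1)); // mask disabled: never masked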

    struct ComputeBasePtrOfStridedBatch
    {
        ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
                                     index_t BatchStrideB,
                                     index_t BatchStrideB1,
                                     CGridDesc_G_M_N c_grid_desc_g_m_n)
            : BatchStrideA_(BatchStrideA),
              BatchStrideB_(BatchStrideB),
              BatchStrideB1_(BatchStrideB1),
        ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
                                     const BGridDesc_G_N_K& b_grid_desc_g_n_k,
                                     const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
                                     const CGridDesc_G_M_N& c_grid_desc_g_m_n)
            : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
              b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
              b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
              c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
        {
        }

        __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideA_);
            return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideB_);
            return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideB1_);
            return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const

@@ -457,9 +317,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        }

        private:
        index_t BatchStrideA_;
        index_t BatchStrideB_;
        index_t BatchStrideB1_;
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        BGridDesc_G_N_K b_grid_desc_g_n_k_;
        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
    };
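What the descriptor-based getters buy over the old scalar batch strides, in a simplified form (a hypothetical stand-in type, not the CK descriptor API): with a single scalar stride, batch g starts at g * BatchStride; with a merged and possibly permuted G dimension, the offset must come from the descriptor, which is what CalculateOffset(make_multi_index(g_idx, 0, 0)) computes above.

struct GMNDescSketch
{
    long long stride_g; // stride of the merged G dimension

    // offset of element (g, 0, 0), i.e. the base pointer of batch g
    long long CalculateOffset(int g) const { return static_cast<long long>(g) * stride_g; }
};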

@@ -523,47 +383,59 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched,
        matrix_padder.PadN,
        MaskOutUpperTriangle>;
        Transform::matrix_padder.PadN,
        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;

    // Argument
    // FIXME: constness
    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a_grid,
                 const BDataType* p_b_grid,
                 const B1DataType* p_b1_grid,
                 CDataType* p_c_grid,
                 index_t MRaw,
                 index_t NRaw,
                 index_t KRaw,
                 index_t Gemm1NRaw, // = ORaw
                 index_t Batch,
                 std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                 std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
                 index_t StrideA,
                 index_t StrideB,
                 index_t StrideB1,
                 index_t BatchStrideA,
                 index_t BatchStrideB,
                 index_t BatchStrideB1,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 AccElementwiseOperation acc_element_op,
                 B1ElementwiseOperation b1_element_op,
                 CElementwiseOperation c_element_op)
        Argument(
            const ADataType* p_a_grid,
            const BDataType* p_b_grid,
            const B1DataType* p_b1_grid,
            CDataType* p_c_grid,
            const std::array<void*, NumAcc0Bias> p_acc0_biases,
            const std::array<void*, NumAcc1Bias> p_acc1_biases,
            const std::vector<index_t>& a_gs_ms_ks_lengths,
            const std::vector<index_t>& a_gs_ms_ks_strides,
            const std::vector<index_t>& b_gs_ns_ks_lengths,
            const std::vector<index_t>& b_gs_ns_ks_strides,
            const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
            const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
            const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
            const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
            const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
            const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
            const std::array<std::vector<ck::index_t>, NumAcc1Bias>
                acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
            const std::array<std::vector<ck::index_t>, NumAcc1Bias>
                acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
            AElementwiseOperation a_element_op,
            BElementwiseOperation b_element_op,
            AccElementwiseOperation acc_element_op,
            B1ElementwiseOperation b1_element_op,
            CElementwiseOperation c_element_op)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_b1_grid_{p_b1_grid},
              p_c_grid_{p_c_grid},
              a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
              b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
              b1_grid_desc_bk0_n_bk1_{
                  DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)},
              c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
                                                                 c_gs_ms_gemm1ns_strides)},
              c_grid_desc_g_m_n_{DeviceOp::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
                                                                     c_gs_ms_gemm1ns_strides)},
              a_grid_desc_ak0_m_ak1_{
                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_bk0_n_bk1_{
                  DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
              b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
              c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
                                                                  c_gs_ms_gemm1ns_strides)},
              a_grid_desc_g_m_k_{
                  Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_g_n_k_{
                  Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
              b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
              c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
                                                                      c_gs_ms_gemm1ns_strides)},
              c_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
              a_element_op_{a_element_op},

@@ -571,14 +443,31 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
              acc_element_op_{acc_element_op},
              b1_element_op_{b1_element_op},
              c_element_op_{c_element_op},
              batch_count_(Batch),
              compute_base_ptr_of_batch_{
                  BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_},
              c0_matrix_mask_{NRaw},
              raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw},
              c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()},
              c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()}
              c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)},
              raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
                                            b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
                                            b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
                                            b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]},
              a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
                               a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
              b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1],
                               b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
              b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1],
                                b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
              c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
                                    c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
              batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
              compute_base_ptr_of_batch_{
                  a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_}
        {
            // TODO ANT: implement bias addition
            ignore = p_acc0_biases;
            ignore = p_acc1_biases;
            ignore = acc0_biases_gs_ms_ns_lengths;
            ignore = acc0_biases_gs_ms_ns_strides;
            ignore = acc1_biases_gs_ms_gemm1ns_lengths;
            ignore = acc1_biases_gs_ms_gemm1ns_strides;

            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                           b_grid_desc_bk0_n_bk1_,
                                           b1_grid_desc_bk0_n_bk1_,
@@ -591,34 +480,66 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            }
        }
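For callers, the practical effect of the new length/stride interface is that a permuted attention layout is expressed purely through strides; no explicit transpose kernel is launched. A hypothetical host-side setup with illustrative sizes:

#include <vector>

int main()
{
    const int B = 2, H = 8, M = 128, K = 64;

    // lengths in (G0, G1, M, K) = (batch, head, sequence, head_dim) order
    std::vector<int> a_gs_ms_ks_lengths{B, H, M, K};

    // strides describe an underlying [B, M, H, K] memory layout, so the
    // head dimension (a G dim) strides faster than the sequence dimension:
    // stride_B = M*H*K, stride_H = K, stride_M = H*K, stride_K = 1
    std::vector<int> a_gs_ms_ks_strides{M * H * K, K, H * K, 1};
}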

        // private:
        void Print() const
        {
            std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
                      << a_grid_desc_g_m_k_.GetLength(I1) << ", "
                      << a_grid_desc_g_m_k_.GetLength(I2) << '\n';
            // a_grid_desc_g_m_k_.Print();
            std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
                      << b_grid_desc_g_n_k_.GetLength(I1) << ", "
                      << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
            // b_grid_desc_g_n_k_.Print();
            std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
                      << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
                      << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
            // b1_grid_desc_g_n_k_.Print();
            std::cout << "c_grid_desc_g_m_n_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
                      << c_grid_desc_g_m_n_.GetLength(I1) << ", "
                      << c_grid_desc_g_m_n_.GetLength(I2) << '\n';
            // c_grid_desc_g_m_n_.Print();
        }

        // pointers
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        const B1DataType* p_b1_grid_;
        CDataType* p_c_grid_;

        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
        CGridDesc_M_N c_grid_desc_m_n_;
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        BGridDesc_G_N_K b_grid_desc_g_n_k_;
        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock_;

        // block-to-c-tile map
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;

        // element-wise op
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        AccElementwiseOperation acc_element_op_;
        B1ElementwiseOperation b1_element_op_;
        CElementwiseOperation c_element_op_;
        index_t batch_count_;
        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;

        // check C0 masking and padding
        C0MatrixMask c0_matrix_mask_;

        // For robust IsSupportedArgument() check
        std::vector<index_t> raw_lengths_m_n_k_o_;
        index_t c_extent_lowest_;
        index_t c_stride_lowest_;
        std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
        std::vector<index_t> a_mz_kz_strides_;
        std::vector<index_t> b_nz_kz_strides_;
        std::vector<index_t> b1_nz_kz_strides_;
        std::vector<index_t> c_mz_gemm1nz_strides_;

        index_t batch_count_;
        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
    };

    // Invoker
@@ -628,13 +549,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                            arg.b_grid_desc_bk0_n_bk1_,
                                            arg.b1_grid_desc_bk0_n_bk1_,
                                            arg.c_grid_desc_m_n_,
                                            arg.block_2_ctile_map_))
            if(!DeviceOp::IsSupportedArgument(arg))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
                throw std::runtime_error("wrong! unsupported argument");
            }

            const index_t grid_size =

@@ -719,17 +636,24 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle

        static bool IsSupportedArgument(const Argument& arg)
        {
#if 0
            arg.Print();
#endif

            if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
            {
                return false;
            }

            // TODO ANT: Check if tensor specialization & strides mismatch

            // Check if C permute dimension matches GEMM + GEMM shape
            const index_t c_g       = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
            const index_t c_m       = arg.c_grid_desc_m_n_.GetLength(I0);
            const index_t c_gemm1n  = arg.c_grid_desc_m_n_.GetLength(I1);
            const index_t a_m       = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
            const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);

            if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
            {
                return false;
@@ -737,19 +661,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle

            // Note: we need raw lengths since the threadwise copy cannot handle vector loads
            // when part of the vector is out of bounds
            const auto MRaw      = arg.raw_lengths_m_n_k_o_[0];
            const auto NRaw      = arg.raw_lengths_m_n_k_o_[1];
            const auto KRaw      = arg.raw_lengths_m_n_k_o_[2];
            const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
            // Note: need the lowest dim in Ms/Ns/Ks/Os, not the merged M/N/K/O
            const auto MzRaw      = arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
            const auto NzRaw      = arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
            const auto KzRaw      = arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
            const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3];

            // Check scalar per vector requirement
            const auto a_extent_lowest =
                is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
            const auto b_extent_lowest =
                is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
            const auto b1_extent_lowest =
                is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
            const auto c_extent_lowest = arg.c_extent_lowest_;
            const auto a_extent_lowest  = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
            const auto b_extent_lowest  = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
            const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
            const auto c_extent_lowest  = Gemm1NzRaw;

            if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
                 b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
@@ -759,8 +681,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                return false;
            }

            // Check vector store requirement; assumes last dimension in N to be contiguous
            if(arg.c_stride_lowest_ != 1)
            // Check vector load/store requirement
            const auto a_stride_lowest =
                ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
            const auto b_stride_lowest =
                BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0];
            const auto b1_stride_lowest =
                B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1]
                                                 : arg.b1_nz_kz_strides_[0];
            const auto c_stride_lowest =
                arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes the lowest dim in Gemm1Ns to be
                                              // contiguous

            if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
                 c_stride_lowest == 1))
            {
                return false;
            }
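The per-tensor test behind these checks reduces to divisibility plus contiguity; a minimal sketch (per-tensor idea only; the kernel's aggregate check combines the stride conditions across A/B0/B1/C slightly differently, as seen above):

// An extent can be read with N-wide vector loads only if it divides evenly
// into vectors and its elements are contiguous in memory.
constexpr bool can_vectorize(int extent_lowest, int stride_lowest, int scalar_per_vector)
{
    return extent_lowest % scalar_per_vector == 0 && stride_lowest == 1;
}
static_assert(can_vectorize(64, 1, 8));  // K = 64, contiguous: 8-wide loads OK
static_assert(!can_vectorize(60, 1, 8)); // 60 is not a multiple of 8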

@@ -778,46 +710,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             const B1DataType* p_b1,
                             CDataType* p_c,
                             index_t MRaw,
                             index_t NRaw,
                             index_t KRaw,
                             index_t Gemm1NRaw,
                             index_t Batch,
                             std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                             std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideB1,
                             index_t BatchStrideA,
                             index_t BatchStrideB,
                             index_t BatchStrideB1,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             AccElementwiseOperation acc_element_op,
                             B1ElementwiseOperation b1_element_op,
                             CElementwiseOperation c_element_op)
    static auto MakeArgument(
        const ADataType* p_a,
        const BDataType* p_b,
        const B1DataType* p_b1,
        CDataType* p_c,
        const std::array<void*, NumAcc0Bias> p_acc0_biases,
        const std::array<void*, NumAcc1Bias> p_acc1_biases,
        const std::vector<index_t>& a_gs_ms_ks_lengths,
        const std::vector<index_t>& a_gs_ms_ks_strides,
        const std::vector<index_t>& b_gs_ns_ks_lengths,
        const std::vector<index_t>& b_gs_ns_ks_strides,
        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
        const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
        const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
        const std::array<std::vector<ck::index_t>, NumAcc1Bias>
            acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
        const std::array<std::vector<ck::index_t>, NumAcc1Bias>
            acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
        AElementwiseOperation a_element_op,
        BElementwiseOperation b_element_op,
        AccElementwiseOperation acc_element_op,
        B1ElementwiseOperation b1_element_op,
        CElementwiseOperation c_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_b1,
                        p_c,
                        MRaw,
                        NRaw,
                        KRaw,
                        Gemm1NRaw,
                        Batch,
                        c_gs_ms_gemm1ns_lengths,
                        c_gs_ms_gemm1ns_strides,
                        StrideA,
                        StrideB,
                        StrideB1,
                        BatchStrideA,
                        BatchStrideB,
                        BatchStrideB1,
                        p_acc0_biases,
                        p_acc1_biases,
                        a_gs_ms_ks_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_lengths,
                        b_gs_ns_ks_strides,
                        b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
                        b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                        c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                        c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
                        acc0_biases_gs_ms_ns_lengths,
                        acc0_biases_gs_ms_ns_strides,
                        acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
                        acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
                        a_element_op,
                        b_element_op,
                        acc_element_op,

@@ -829,47 +766,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle

    // polymorphic
    // FIXME: constness
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        const void* p_b1,
                        void* p_c,
                        index_t MRaw,
                        index_t NRaw,
                        index_t KRaw,
                        index_t Gemm1NRaw,
                        index_t Batch,
                        std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                        std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
                        index_t StrideA,
                        index_t StrideB,
                        index_t StrideB1,
                        index_t BatchStrideA,
                        index_t BatchStrideB,
                        index_t BatchStrideB1,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        AccElementwiseOperation acc_element_op,
                        B1ElementwiseOperation b1_element_op,
                        CElementwiseOperation c_element_op) override
    std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const void* p_a,
        const void* p_b,
        const void* p_b1,
        void* p_c,
        const std::array<void*, NumAcc0Bias> p_acc0_biases,
        const std::array<void*, NumAcc1Bias> p_acc1_biases,
        const std::vector<index_t>& a_gs_ms_ks_lengths,
        const std::vector<index_t>& a_gs_ms_ks_strides,
        const std::vector<index_t>& b_gs_ns_ks_lengths,
        const std::vector<index_t>& b_gs_ns_ks_strides,
        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
        const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
        const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
        const std::array<std::vector<ck::index_t>, NumAcc1Bias>
            acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
        const std::array<std::vector<ck::index_t>, NumAcc1Bias>
            acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
        AElementwiseOperation a_element_op,
        BElementwiseOperation b_element_op,
        AccElementwiseOperation acc_element_op,
        B1ElementwiseOperation b1_element_op,
        CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<const B1DataType*>(p_b1),
                                          static_cast<CDataType*>(p_c),
                                          MRaw,
                                          NRaw,
                                          KRaw,
                                          Gemm1NRaw,
                                          Batch,
                                          c_gs_ms_gemm1ns_lengths,
                                          c_gs_ms_gemm1ns_strides,
                                          StrideA,
                                          StrideB,
                                          StrideB1,
                                          BatchStrideA,
                                          BatchStrideB,
                                          BatchStrideB1,
                                          p_acc0_biases, // cast in struct Argument
                                          p_acc1_biases, // cast in struct Argument
                                          a_gs_ms_ks_lengths,
                                          a_gs_ms_ks_strides,
                                          b_gs_ns_ks_lengths,
                                          b_gs_ns_ks_strides,
                                          b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
                                          b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                                          c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                                          c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
                                          acc0_biases_gs_ms_ns_lengths,
                                          acc0_biases_gs_ms_ns_strides,
                                          acc1_biases_gs_ms_gemm1ns_lengths,
                                          acc1_biases_gs_ms_gemm1ns_strides,
                                          a_element_op,
                                          b_element_op,
                                          acc_element_op,
@@ -901,7 +842,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            << Gemm1NPerBlock << ", "
            << Gemm1KPerBlock << ", "
            << B1K1 << ", "
            << getGemmSpecializationString(GemmSpec) << ">";
            << getGemmSpecializationString(GemmSpec) << ", "
            << "ASpec" << getTensorSpecializationString(ASpec) << ", "
            << "B0Spec" << getTensorSpecializationString(BSpec) << ", "
            << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
            << "CSpec" << getTensorSpecializationString(CSpec) << ", "
            << getMaskingSpecializationString(MaskingSpec) << ">";
        // clang-format on

        return str.str();

@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -196,7 +197,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
          BElementwiseOperation,
          AccElementwiseOperation,
          B1ElementwiseOperation,
          CElementwiseOperation>
          CElementwiseOperation,
          MaskOutUpperTriangle>
{
    using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle;

@@ -315,29 +317,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
    }

    // to track the points which need to be set to -inf on C0
    // Note: no need to reset M padding value, because they will not be stored out.
    struct C0MatrixMask
    {
        C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}

        __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }

        __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
        {
            return n >= NRaw_;
        }

        __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
        {
            return IsUpperTriangle(m, n) || IsNOutOfBound(n);
        }

        private:
        // index_t MRaw_;
        index_t NRaw_;
    };

    struct ComputeBasePtrOfStridedBatch
    {
        ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
@@ -383,6 +362,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
    using CGridDesc_M_N        = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

    using C0MatrixMask = conditional_t<MaskOutUpperTriangle,
                                       C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
                                       C0MatrixMask_impl<MaskDisabledPredicate>>;

    // GridwiseGemm
    using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
        ADataType, // TODO: distinguish A/B datatype
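Since the non-permute op keeps a bool template parameter rather than the enum, the mask type here is chosen with conditional_t instead of if constexpr. A simplified standalone sketch of that type-level selection (stand-in types, not the CK headers):

#include <type_traits>

struct MaskOutUpperTrianglePredicate { };
struct MaskDisabledPredicate { };

template <typename Predicate>
struct C0MatrixMask_impl { Predicate predicate_; };

// A bool template parameter picks the mask implementation at compile time.
template <bool MaskOutUpperTriangle>
using C0MatrixMask = std::conditional_t<MaskOutUpperTriangle,
                                        C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
                                        C0MatrixMask_impl<MaskDisabledPredicate>>;

static_assert(std::is_same_v<C0MatrixMask<true>,
                             C0MatrixMask_impl<MaskOutUpperTrianglePredicate>>);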
@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {
namespace tensor_operation {
namespace device {

enum struct MaskingSpecialization
{
    MaskDisabled,
    MaskOutUpperTriangle
};

inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
{
    switch(s)
    {
    case MaskingSpecialization::MaskDisabled: return "MaskDisabled";
    case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle";
    default: return "Unrecognized specialization!";
    }
}

struct MaskDisabledPredicate
{
    __host__ __device__ constexpr bool operator()(index_t /*m*/, index_t /*n*/) const
    {
        return false;
    }

    __host__ __device__ constexpr bool
    IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const
    {
        return false;
    }
};

struct MaskOutUpperTrianglePredicate
{
    __host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; }

    __host__ __device__ constexpr bool
    IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t /*n_tile*/) const
    {
        return operator()(m + m_tile - 1, n);
    }
};
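The IsTileSkippable test is worth a worked example: for an upper-triangle mask, the tile element least likely to be masked is the bottom-left one, so testing (m + m_tile - 1, n) decides the whole tile. A sketch with hypothetical 128-wide tiles:

// Within a tile covering rows [m, m + m_tile) and columns [n, n + n_tile),
// the element with the largest row and smallest column is (m + m_tile - 1, n).
// If even that element satisfies n > row, every element of the tile does,
// and the whole tile can be skipped.
constexpr bool is_tile_skippable(int m, int n, int m_tile)
{
    return n > m + m_tile - 1;
}
static_assert(is_tile_skippable(0, 128, 128));  // fully above the diagonal: skip
static_assert(!is_tile_skippable(0, 0, 128));   // straddles the diagonal: compute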

// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
    C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}

    __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
    {
        return n >= NRaw_;
    }

    __host__ __device__ constexpr bool IsMaskedElement(index_t m, index_t n) const
    {
        return predicate_(m, n) || IsNOutOfBound(n);
    }

    __host__ __device__ constexpr bool
    IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const
    {
        return predicate_.IsTileSkippable(m, n, m_tile, n_tile);
    }

    private:
    // index_t MRaw_;
    index_t NRaw_;
    MaskOutPredicate predicate_;
};
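A host-side usage sketch of the mask abstraction, with simplified stand-in types rather than the CK build: a causal mask over a score matrix with NRaw = 5, so column 5 is padding and always masked.

#include <cassert>

struct UpperTri { constexpr bool operator()(int m, int n) const { return n > m; } };

template <typename Pred>
struct Mask
{
    int NRaw_;
    Pred pred_{};
    constexpr bool IsMaskedElement(int m, int n) const { return pred_(m, n) || n >= NRaw_; }
};

int main()
{
    Mask<UpperTri> mask{5};
    assert(mask.IsMaskedElement(0, 1));  // above the diagonal: masked
    assert(!mask.IsMaskedElement(1, 1)); // on the diagonal: kept
    assert(mask.IsMaskedElement(3, 5));  // N out of bound (padding): masked
}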

} // namespace device
} // namespace tensor_operation
} // namespace ck