mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
Introduce gemm_softmax_gemm to codegen.
This commit is contained in:
@@ -3,15 +3,17 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/stream_config.hpp"
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
struct BaseArgument
|
||||
{
|
||||
BaseArgument() = default;
|
||||
@@ -36,6 +38,7 @@ struct BaseInvoker
|
||||
|
||||
virtual ~BaseInvoker() {}
|
||||
};
|
||||
#endif
|
||||
|
||||
struct BaseOperator
|
||||
{
|
||||
@@ -43,6 +46,7 @@ struct BaseOperator
|
||||
BaseOperator(const BaseOperator&) = default;
|
||||
BaseOperator& operator=(const BaseOperator&) = default;
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
|
||||
virtual std::string GetTypeString() const { return ""; }
|
||||
|
||||
@@ -66,7 +70,7 @@ struct BaseOperator
|
||||
assert(p_arg);
|
||||
p_arg->p_workspace_ = p_workspace;
|
||||
}
|
||||
|
||||
#endif
|
||||
virtual ~BaseOperator() {}
|
||||
};
|
||||
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#endif
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
@@ -28,6 +29,7 @@ template <typename ALayout,
|
||||
bool MaskOutUpperTriangle> // TODO: enum for mask type
|
||||
struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
|
||||
{
|
||||
#ifndef __HIPCC_RTC__
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b0,
|
||||
@@ -53,6 +55,7 @@ struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -2,9 +2,11 @@
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <array>
|
||||
#endif
|
||||
|
||||
#include "ck/utility/array.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -34,6 +36,7 @@ struct DeviceGemmMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
@@ -51,6 +54,7 @@ struct DeviceGemmMultipleD : public BaseOperator
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
#endif
|
||||
};
|
||||
|
||||
// GEMM:
|
||||
@@ -76,6 +80,7 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
@@ -94,6 +99,7 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -28,7 +28,7 @@ enum struct GemmSpecialization
|
||||
NKOPadding,
|
||||
MNKOPadding,
|
||||
};
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
inline std::string getGemmSpecializationString(const GemmSpecialization& s)
|
||||
{
|
||||
switch(s)
|
||||
@@ -52,6 +52,7 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
|
||||
default: return "Unrecognized specialization!";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -3,8 +3,12 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#endif
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
@@ -15,8 +19,6 @@
|
||||
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -40,27 +42,27 @@ template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatAB* __restrict__ p_b1_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const AccElementwiseOperation acc_element_op,
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatAB* __restrict__ p_b1_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const AccElementwiseOperation acc_element_op,
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
@@ -430,6 +432,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
matrix_padder.PadN,
|
||||
MaskOutUpperTriangle>;
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
@@ -604,6 +607,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
@@ -611,6 +615,97 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
return true;
|
||||
}
|
||||
|
||||
static constexpr bool
|
||||
IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
|
||||
{
|
||||
// check vector load/store
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// check vector load of A
|
||||
if constexpr(is_same_v<ALayout, Row>)
|
||||
{
|
||||
if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col>)
|
||||
{
|
||||
if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of B
|
||||
if constexpr(is_same_v<BLayout, Row>)
|
||||
{
|
||||
if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<BLayout, Col>)
|
||||
{
|
||||
if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of B1
|
||||
if constexpr(is_same_v<B1Layout, Row>)
|
||||
{
|
||||
if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<B1Layout, Col>)
|
||||
{
|
||||
if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of C
|
||||
if constexpr(is_same_v<CLayout, Row>)
|
||||
{
|
||||
if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<CLayout, Col>)
|
||||
{
|
||||
if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
@@ -765,8 +860,271 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
|
||||
return str.str();
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class ADesc, class BDesc, class B1Desc, class CDesc>
|
||||
struct Descriptor
|
||||
{
|
||||
template <class AGridDescriptor>
|
||||
static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
|
||||
{
|
||||
const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
|
||||
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <class BGridDescriptor>
|
||||
static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
|
||||
{
|
||||
const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
|
||||
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <class B1GridDescriptor>
|
||||
static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
|
||||
{
|
||||
const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
|
||||
|
||||
const auto N = b1_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b1_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
const auto B1K0 = K / B1K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b1_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <class CGridDescriptor>
|
||||
static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
|
||||
{
|
||||
return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
|
||||
}
|
||||
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
|
||||
using B1GridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
B1GridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
Gemm1NPerBlock,
|
||||
Gemm1KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
B1K1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
Gemm1NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
true,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
true,
|
||||
BBlockLdsExtraN,
|
||||
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
B1BlockTransferThreadClusterArrangeOrder,
|
||||
B1BlockTransferSrcAccessOrder,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
B1BlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
matrix_padder.PadN,
|
||||
MaskOutUpperTriangle>;
|
||||
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
|
||||
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
|
||||
CGridDesc_M_N c_grid_desc_m_n;
|
||||
C0MatrixMask c0_matrix_mask;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_descriptor_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op;
|
||||
BElementwiseOperation b_element_op;
|
||||
B1ElementwiseOperation b1_element_op;
|
||||
CElementwiseOperation c_element_op;
|
||||
|
||||
bool has_main_k_block_loop = true;
|
||||
bool is_valid = false;
|
||||
|
||||
constexpr Descriptor(ADesc a,
|
||||
BDesc b,
|
||||
B1Desc b1,
|
||||
CDesc c,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
B1ElementwiseOperation b1_element_op_,
|
||||
CElementwiseOperation c_element_op_)
|
||||
: a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)},
|
||||
b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
|
||||
b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
|
||||
c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
|
||||
block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
|
||||
c_grid_descriptor_mblock_mperblock_nblock_nperblock{
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n)},
|
||||
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
|
||||
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
|
||||
c0_matrix_mask{c.GetLength(I1)},
|
||||
a_element_op{a_element_op_},
|
||||
b_element_op{b_element_op_},
|
||||
b1_element_op{b1_element_op_},
|
||||
c_element_op{c_element_op_},
|
||||
is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b1_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_m_n,
|
||||
block_2_ctile_map) and
|
||||
IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
|
||||
b_grid_desc_bk0_n_bk1.GetLength(I1),
|
||||
a_grid_desc_ak0_m_ak1.GetLength(I0) *
|
||||
a_grid_desc_ak0_m_ak1.GetLength(I2),
|
||||
b1_grid_desc_bk0_n_bk1.GetLength(I1))}
|
||||
{
|
||||
}
|
||||
|
||||
constexpr bool IsValid() const { return is_valid; }
|
||||
};
|
||||
|
||||
template <class ADesc, class BDesc, class B1Desc, class CDesc>
|
||||
static constexpr auto
|
||||
make_descriptor(ADesc a,
|
||||
BDesc b,
|
||||
B1Desc b1,
|
||||
CDesc c,
|
||||
AElementwiseOperation a_element_op = AElementwiseOperation{},
|
||||
BElementwiseOperation b_element_op = BElementwiseOperation{},
|
||||
B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
|
||||
CElementwiseOperation c_element_op = CElementwiseOperation{})
|
||||
{
|
||||
return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
|
||||
a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
|
||||
}
|
||||
|
||||
template <class Desc>
|
||||
__device__ static void Run(const Desc& desc,
|
||||
const float scale,
|
||||
const ADataType* __restrict__ p_a_grid,
|
||||
const ADataType* __restrict__ p_b_grid,
|
||||
const ADataType* __restrict__ p_b1_grid,
|
||||
CDataType* __restrict__ p_c_grid)
|
||||
{
|
||||
#ifndef __HIPCC_RTC__
|
||||
assert(desc.is_valid);
|
||||
#endif
|
||||
__shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
AccElementwiseOperation acc_element_op{scale};
|
||||
|
||||
if(desc.has_main_k_block_loop)
|
||||
{
|
||||
Desc::GridwiseGemm::template Run<true>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_b1_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
acc_element_op,
|
||||
desc.b1_element_op,
|
||||
desc.c_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.b1_grid_desc_bk0_n_bk1,
|
||||
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_ctile_map,
|
||||
desc.c0_matrix_mask);
|
||||
}
|
||||
else
|
||||
{
|
||||
Desc::GridwiseGemm::template Run<false>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_b1_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
acc_element_op,
|
||||
desc.b1_element_op,
|
||||
desc.c_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.b1_grid_desc_bk0_n_bk1,
|
||||
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_ctile_map,
|
||||
desc.c0_matrix_mask);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
} // namespace ck
|
||||
@@ -3,8 +3,12 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#endif
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
@@ -14,8 +18,6 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -35,22 +37,22 @@ template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
|
||||
const BDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
|
||||
const BDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
@@ -225,9 +227,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
|
||||
const std::array<index_t, NumDTensor>& NRaws,
|
||||
const std::array<index_t, NumDTensor>& DsStride)
|
||||
static auto MakeDsGridDescriptor_M_N(const Array<index_t, NumDTensor>& MRaws,
|
||||
const Array<index_t, NumDTensor>& NRaws,
|
||||
const Array<index_t, NumDTensor>& DsStride)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -309,6 +311,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
@@ -498,6 +501,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
|
||||
{
|
||||
// check vector load/store
|
||||
@@ -578,6 +583,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
@@ -676,11 +682,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<LoopScheduler, std::string> LoopSchedToString{
|
||||
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
|
||||
std::map<LoopScheduler, std::string> LoopSchedToString{{LoopScheduler::Default, "Default"},
|
||||
{ LoopScheduler::Interwave,
|
||||
"Interwave" }};
|
||||
|
||||
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
|
||||
{PipelineVersion::v2, "v2"}};
|
||||
{ PipelineVersion::v2,
|
||||
"v2" }};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmMultipleD_Xdl_CShuffle"
|
||||
@@ -709,6 +717,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
|
||||
return str.str();
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class ADesc, class BDesc, class DsDesc, class EDesc>
|
||||
struct Descriptor
|
||||
@@ -847,7 +856,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
EDataType* __restrict__ p_e_grid)
|
||||
{
|
||||
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
#ifndef __HIPCC_RTC__
|
||||
assert(desc.IsValid());
|
||||
#endif
|
||||
if(desc.has_main_k_block_loop)
|
||||
{
|
||||
GridwiseGemm::template Run<true>(p_a_grid,
|
||||
|
||||
@@ -13,6 +13,7 @@ enum struct MaskingSpecialization
|
||||
MaskOutUpperTriangle
|
||||
};
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
|
||||
{
|
||||
switch(s)
|
||||
@@ -22,6 +23,7 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s
|
||||
default: return "Unrecognized specialization!";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
struct MaskDisabledPredicate
|
||||
{
|
||||
@@ -53,7 +55,7 @@ struct MaskOutUpperTrianglePredicate
|
||||
template <typename MaskOutPredicate>
|
||||
struct C0MatrixMask_impl
|
||||
{
|
||||
__host__ __device__ C0MatrixMask_impl(index_t NRaw)
|
||||
__host__ __device__ constexpr C0MatrixMask_impl(index_t NRaw)
|
||||
: NRaw_(NRaw), predicate_(MaskOutPredicate{})
|
||||
{
|
||||
}
|
||||
|
||||
@@ -430,6 +430,7 @@ struct G_NDHW : public BaseTensorLayout
|
||||
|
||||
} // namespace convolution
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
template <
|
||||
typename Layout,
|
||||
typename std::enable_if<std::is_base_of<BaseTensorLayout, Layout>::value, bool>::type = false>
|
||||
@@ -438,6 +439,7 @@ std::ostream& operator<<(std::ostream& os, const Layout&)
|
||||
os << Layout::name;
|
||||
return os;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace tensor_layout
|
||||
} // namespace ck
|
||||
|
||||
@@ -340,8 +340,8 @@ struct Bilinear
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
|
||||
std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
|
||||
__host__ __device__ constexpr void operator()<int8_t, int32_t, int8_t>(
|
||||
int8_t& y, const int32_t& x0, const int8_t& x1) const
|
||||
{
|
||||
y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
|
||||
beta_ * type_convert<float>(x1));
|
||||
|
||||
@@ -466,7 +466,7 @@ struct FastGelu
|
||||
|
||||
template <typename Y, typename X>
|
||||
__device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
template <>
|
||||
__host__ void operator()<float, float>(float& y, const float& x) const
|
||||
{
|
||||
@@ -477,7 +477,7 @@ struct FastGelu
|
||||
const float emu = exp(u);
|
||||
y = x / (1.f + emu);
|
||||
}
|
||||
|
||||
#endif
|
||||
// device code, use lower precision "__ocml_exp_f32" and "rcp"
|
||||
template <>
|
||||
__device__ void operator()<float, float>(float& y, const float& x) const
|
||||
|
||||
@@ -7,8 +7,10 @@
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/tensor_description/tensor_adaptor.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <limits>
|
||||
#include <stdlib.h>
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -979,7 +981,7 @@ struct BlockToCTileMap_3DGrid_KSplit
|
||||
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
|
||||
|
||||
return std::make_tuple(N0, M0, k_split);
|
||||
return ck::make_tuple(N0, M0, k_split);
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
@@ -1103,7 +1105,7 @@ struct BlockToCTileMap_GemmStreamK
|
||||
uint32_t dp_for_sk_iters = k_iters_per_tile.get();
|
||||
|
||||
uint32_t best_sk_score =
|
||||
std::numeric_limits<int>::max(); // we need to find the smallest sk iters
|
||||
ck::NumericLimits<int>::Max(); // we need to find the smallest sk iters
|
||||
for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
|
||||
tentative_sk_blocks++)
|
||||
{
|
||||
|
||||
@@ -475,9 +475,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
|
||||
template <typename DsLayout, GemmSpecialization GemmSpec>
|
||||
__host__ __device__ static auto
|
||||
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
|
||||
const std::array<index_t, NumDTensor>& NRaws,
|
||||
const std::array<index_t, NumDTensor>& DsStride)
|
||||
MakeDsGridDescriptor_M_N(const Array<index_t, NumDTensor>& MRaws,
|
||||
const Array<index_t, NumDTensor>& NRaws,
|
||||
const Array<index_t, NumDTensor>& DsStride)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -941,7 +941,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
const index_t K,
|
||||
const index_t StrideA,
|
||||
const index_t StrideB,
|
||||
const std::array<index_t, NumDTensor> StrideDs,
|
||||
const Array<index_t, NumDTensor> StrideDs,
|
||||
const index_t StrideE,
|
||||
const Block2ETileMap& block_2_etile_map)
|
||||
{
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#endif
|
||||
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
|
||||
@@ -53,12 +55,15 @@ constexpr auto GridwiseGemmPipeline_Selector()
|
||||
}
|
||||
else
|
||||
{
|
||||
#ifndef __HIPCC_RTC__
|
||||
std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
|
||||
{
|
||||
switch(p)
|
||||
@@ -71,3 +76,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
|
||||
}
|
||||
return os;
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user