upgrade from clang-format-12 to clang-format-18 (#2568)

* upgrade to clang-format-18

* update to clang-format-18 in pre-commit-config
Illia Silin
2025-07-28 11:34:07 -07:00
committed by GitHub
parent 9786087010
commit 504b101da3
373 changed files with 3351 additions and 3760 deletions
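Almost every hunk below is a whitespace-only rewrap: clang-format-18 re-indents the `__launch_bounds__` continuation line and re-aligns long kernel signatures, so many before/after pairs contain the same tokens. The visible content-level changes are `operator<<` chains, which clang-format-18 packs up to the column limit instead of breaking after every string literal. A minimal sketch of that rewrap, assuming a hypothetical `check` helper (the message text and the final chain layout are copied from the DeviceCGemm_4Gemm_Xdl_CShuffle hunk below):

#include <sstream>
#include <stdexcept>

// Hypothetical wrapper used only to demonstrate the rewrap; the diff
// applies the same change inside DeviceCGemm_4Gemm_Xdl_CShuffle.
void check(const void* parg)
{
    if(parg == nullptr)
    {
        std::ostringstream err;
        // clang-format-12 broke after each string literal:
        //   err << "Provided argument pointer is not of an Argument class!"
        //       << " In " << __FILE__ << ":" << __LINE__ << ...;
        // clang-format-18 fills each line up to the column limit:
        err << "Provided argument pointer is not of an Argument class!" << " In " << __FILE__
            << ":" << __LINE__ << ", in function: " << __func__;
        throw std::runtime_error(err.str());
    }
}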

View File

@@ -205,25 +205,25 @@ template <typename GridwiseGemm,
bool isMultiB>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
-AsPointer p_as_grid,
-BsPointer p_bs_grid,
-DsPointer p_ds_grid,
-EDataType* __restrict__ p_e_grid,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const CDEElementwiseOperation cde_element_op,
-const index_t batch_count,
-const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
-const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
-const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-ds_grid_desc_mblock_mperblock_nblock_nperblock,
-const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-e_grid_desc_mblock_mperblock_nblock_nperblock_,
-const Block2ETileMap block_2_ctile_map,
-const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
+kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
+AsPointer p_as_grid,
+BsPointer p_bs_grid,
+DsPointer p_ds_grid,
+EDataType* __restrict__ p_e_grid,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CDEElementwiseOperation cde_element_op,
+const index_t batch_count,
+const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
+const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
+const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+ds_grid_desc_mblock_mperblock_nblock_nperblock,
+const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+e_grid_desc_mblock_mperblock_nblock_nperblock_,
+const Block2ETileMap block_2_ctile_map,
+const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
device_grouped_conv_fwd_multiple_abd_xdl_cshuffle<

View File

@@ -36,25 +36,25 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_contraction_multiple_d_xdl_cshuffle(
-const FloatAB* __restrict__ p_a_grid,
-const FloatAB* __restrict__ p_b_grid,
-FloatDsPointer p_ds_grid,
-FloatE* __restrict__ p_e_grid,
-const index_t batch_count,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const CDEElementwiseOperation cde_element_op,
-const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-ds_grid_desc_mblock_mperblock_nblock_nperblock,
-const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-e_grid_desc_mblock_mperblock_nblock_nperblock,
-const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-const Block2ETileMap block_2_etile_map)
+kernel_contraction_multiple_d_xdl_cshuffle(
+const FloatAB* __restrict__ p_a_grid,
+const FloatAB* __restrict__ p_b_grid,
+FloatDsPointer p_ds_grid,
+FloatE* __restrict__ p_e_grid,
+const index_t batch_count,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CDEElementwiseOperation cde_element_op,
+const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+ds_grid_desc_mblock_mperblock_nblock_nperblock,
+const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+e_grid_desc_mblock_mperblock_nblock_nperblock,
+const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

View File

@@ -58,21 +58,21 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
-const ABDataType* __restrict__ p_b_grid,
-EDataType* __restrict__ p_e_grid,
-const index_t batch_count,
-const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-e_grid_desc_mblock_mperblock_nblock_nperblock,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const CDEElementwiseOperation cde_element_op,
-const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-const Block2ETileMap block_2_etile_map)
+kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
+const ABDataType* __restrict__ p_b_grid,
+EDataType* __restrict__ p_e_grid,
+const index_t batch_count,
+const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+e_grid_desc_mblock_mperblock_nblock_nperblock,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CDEElementwiseOperation cde_element_op,
+const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
const index_t num_blocks_per_batch =

View File

@@ -39,26 +39,25 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_gemm_gemm_xdl_cshuffle_v1(
-const FloatAB* __restrict__ p_a_grid,
-const FloatAB* __restrict__ p_b_grid,
-const FloatAB* __restrict__ p_b1_grid,
-FloatC* __restrict__ p_c_grid,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const AccElementwiseOperation acc_element_op,
-const B1ElementwiseOperation b1_element_op,
-const CElementwiseOperation c_element_op,
-const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
-const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-c_grid_desc_mblock_mperblock_nblock_nperblock,
-const Block2CTileMap block_2_ctile_map,
-const index_t batch_count,
-const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
+kernel_gemm_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
+const FloatAB* __restrict__ p_b_grid,
+const FloatAB* __restrict__ p_b1_grid,
+FloatC* __restrict__ p_c_grid,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const AccElementwiseOperation acc_element_op,
+const B1ElementwiseOperation b1_element_op,
+const CElementwiseOperation c_element_op,
+const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
+const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+c_grid_desc_mblock_mperblock_nblock_nperblock,
+const Block2CTileMap block_2_ctile_map,
+const index_t batch_count,
+const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

View File

@@ -63,24 +63,24 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid,
-const ABDataType* __restrict__ p_b_grid,
-DsPointer p_ds_grid,
-EDataType* __restrict__ p_e_grid,
-const index_t batch_count,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const CDEElementwiseOperation cde_element_op,
-const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
-const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
-const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-ds_grid_desc_mblock_mperblock_nblock_nperblock,
-const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-e_grid_desc_mblock_mperblock_nblock_nperblock_,
-const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-const Block2ETileMap block_2_etile_map)
+kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid,
+const ABDataType* __restrict__ p_b_grid,
+DsPointer p_ds_grid,
+EDataType* __restrict__ p_e_grid,
+const index_t batch_count,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CDEElementwiseOperation cde_element_op,
+const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
+const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
+const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+ds_grid_desc_mblock_mperblock_nblock_nperblock,
+const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+e_grid_desc_mblock_mperblock_nblock_nperblock_,
+const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))

View File

@@ -52,23 +52,23 @@ template <typename GridwiseGemm,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_gemm_dl_multiple_d(
-const ABDataType* __restrict__ p_a_grid,
-const ABDataType* __restrict__ p_b_grid,
-DsPointer p_ds_grid,
-EDataType* __restrict__ p_e_grid,
-const index_t batch_count,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const CDEElementwiseOperation cde_element_op,
-const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
-const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
-const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
-const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
-const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-const Block2CTileMap block_2_ctile_map)
+kernel_gemm_dl_multiple_d(
+const ABDataType* __restrict__ p_a_grid,
+const ABDataType* __restrict__ p_b_grid,
+DsPointer p_ds_grid,
+EDataType* __restrict__ p_e_grid,
+const index_t batch_count,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CDEElementwiseOperation cde_element_op,
+const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
+const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
+const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
+const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
+const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \

View File

@@ -42,32 +42,32 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_batched_gemm_gemm_xdl_cshuffle_v1(
-const A0B0B1DataType* __restrict__ p_a0_grid,
-const A0B0B1DataType* __restrict__ p_b0_grid,
-D0sPointer p_d0s_grid,
-const A0B0B1DataType* __restrict__ p_b1_grid,
-D1sPointer p_d1s_grid,
-E1DataType* __restrict__ p_e1_grid,
-const A0ElementwiseOperation a0_element_op,
-const B0ElementwiseOperation b0_element_op,
-const CDE0ElementwiseOperation cde0_element_op,
-const B1ElementwiseOperation b1_element_op,
-const CDE1ElementwiseOperation cde1_element_op,
-const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1,
-const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1,
-const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
-const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-d1s_grid_desc_mblock_mperblock_nblock_nperblock,
-const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-e1_grid_desc_mblock_mperblock_nblock_nperblock,
-const Block2E1TileMap block_2_e1tile_map,
-const index_t batch_count,
-const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
+kernel_batched_gemm_gemm_xdl_cshuffle_v1(
+const A0B0B1DataType* __restrict__ p_a0_grid,
+const A0B0B1DataType* __restrict__ p_b0_grid,
+D0sPointer p_d0s_grid,
+const A0B0B1DataType* __restrict__ p_b1_grid,
+D1sPointer p_d1s_grid,
+E1DataType* __restrict__ p_e1_grid,
+const A0ElementwiseOperation a0_element_op,
+const B0ElementwiseOperation b0_element_op,
+const CDE0ElementwiseOperation cde0_element_op,
+const B1ElementwiseOperation b1_element_op,
+const CDE1ElementwiseOperation cde1_element_op,
+const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1,
+const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1,
+const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
+const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+d1s_grid_desc_mblock_mperblock_nblock_nperblock,
+const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+e1_grid_desc_mblock_mperblock_nblock_nperblock,
+const Block2E1TileMap block_2_e1tile_map,
+const index_t batch_count,
+const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -829,10 +829,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
is_same_v<tensor_layout::gemm::ColumnMajor, B0Layout> &&
CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>() &&
(is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ||
-is_same_v<tensor_layout::gemm::ColumnMajor,
-B1Layout>)&&CheckDLayout<tensor_layout::gemm::RowMajor,
-D1sLayout,
-NumD1Tensor>() &&
+is_same_v<tensor_layout::gemm::ColumnMajor, B1Layout>) &&
+CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>() &&
is_same_v<tensor_layout::gemm::RowMajor, E1Layout>))
{
return false;
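The tail of that hunk is one of the few readability fixes rather than a pure re-indent: clang-format-12 fused the closing parenthesis of the multi-line `is_same_v` disjunction onto `&&CheckDLayout<...>`, while clang-format-18 breaks after the `)` and keeps each conjunct whole. A reduced sketch of the same pattern, assuming a hypothetical `layouts_ok` check rather than the template from this file:

#include <type_traits>

// clang-format-12 rendered conditions like this as `double>)&&sizeof...`,
// gluing the `)` to the `&&`; clang-format-18 formats the condition as written.
template <typename B1Layout>
constexpr bool layouts_ok()
{
    return (std::is_same_v<B1Layout, float> ||
            std::is_same_v<B1Layout, double>) &&
           sizeof(B1Layout) >= 4;
}

static_assert(layouts_ok<double>());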

View File

@@ -33,9 +33,9 @@ template <typename GridwiseGemm,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
-kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg)
+kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -79,9 +79,9 @@ template <typename GridwiseGemm,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
-kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg)
+kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// Pass two lds pointer is the key to tell compiler that ds_read/write

View File

@@ -39,26 +39,26 @@ template <typename GridwiseGemm,
bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_batched_gemm_reduce_xdl_cshuffle_v1(
-const FloatAB* __restrict__ p_a_grid,
-const FloatAB* __restrict__ p_b_grid,
-FloatC* __restrict__ p_c_grid,
-ReducePtrsGlobal p_reduces_grid,
-const index_t batch_count,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const CElementwiseOperation c_element_op,
-const ReduceInElementwiseOperations reduce_in_element_ops,
-const ReduceAccElementwiseOperations reduce_out_element_ops,
-const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-c_grid_desc_mblock_mperblock_nblock_nperblock,
-const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
-const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
-const Block2CTileMap block_2_ctile_map)
+kernel_batched_gemm_reduce_xdl_cshuffle_v1(
+const FloatAB* __restrict__ p_a_grid,
+const FloatAB* __restrict__ p_b_grid,
+FloatC* __restrict__ p_c_grid,
+ReducePtrsGlobal p_reduces_grid,
+const index_t batch_count,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CElementwiseOperation c_element_op,
+const ReduceInElementwiseOperations reduce_in_element_ops,
+const ReduceAccElementwiseOperations reduce_out_element_ops,
+const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+c_grid_desc_mblock_mperblock_nblock_nperblock,
+const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
+const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
+const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
const index_t num_blocks_per_batch =

View File

@@ -40,21 +40,21 @@ template <typename DeviceOp,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid,
-const B0DataType* __restrict__ p_b0_grid,
-const B1DataType* __restrict__ p_b1_grid,
-CDataType* __restrict__ p_c_grid,
-index_t M,
-index_t N,
-index_t K,
-index_t O,
-index_t G0,
-index_t G1,
-float alpha,
-bool input_permute,
-bool output_permute)
+kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid,
+const B0DataType* __restrict__ p_b0_grid,
+const B1DataType* __restrict__ p_b1_grid,
+CDataType* __restrict__ p_c_grid,
+index_t M,
+index_t N,
+index_t K,
+index_t O,
+index_t G0,
+index_t G1,
+float alpha,
+bool input_permute,
+bool output_permute)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
@@ -178,15 +178,15 @@ template <typename DeviceOp,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid,
-ODataType* __restrict__ p_out_grid,
-index_t batch_size,
-index_t sequence_length,
-index_t head_count,
-index_t head_size,
-float alpha)
+kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid,
+ODataType* __restrict__ p_out_grid,
+index_t batch_size,
+index_t sequence_length,
+index_t head_count,
+index_t head_size,
+float alpha)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
@@ -310,17 +310,17 @@ template <typename DeviceOp,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid,
-const KVDataType* __restrict__ p_kv_grid,
-ODataType* __restrict__ p_out_grid,
-index_t batch_size,
-index_t q_sequence_length,
-index_t kv_sequence_length,
-index_t head_count,
-index_t head_size,
-float alpha)
+kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid,
+const KVDataType* __restrict__ p_kv_grid,
+ODataType* __restrict__ p_out_grid,
+index_t batch_size,
+index_t q_sequence_length,
+index_t kv_sequence_length,
+index_t head_count,
+index_t head_size,
+float alpha)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))

View File

@@ -43,30 +43,30 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
-const FloatAB* __restrict__ p_a_grid,
-const FloatAB* __restrict__ p_b_grid,
-const FloatAB* __restrict__ p_b1_grid,
-FloatC* __restrict__ p_c_grid,
-D0sPointer p_d0s_grid,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const C0DEElementwiseOperation c0de_element_op,
-const B1ElementwiseOperation b1_element_op,
-const C1DEElementwiseOperation c1de_element_op,
-const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
-const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-c1_grid_desc_mblock_mperblock_nblock_nperblock,
-const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-const Block2CTileMap block_2_ctile_map,
-const index_t batch_count,
-const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
-const C0MatrixMask c0_matrix_mask)
+kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
+const FloatAB* __restrict__ p_a_grid,
+const FloatAB* __restrict__ p_b_grid,
+const FloatAB* __restrict__ p_b1_grid,
+FloatC* __restrict__ p_c_grid,
+D0sPointer p_d0s_grid,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const C0DEElementwiseOperation c0de_element_op,
+const B1ElementwiseOperation b1_element_op,
+const C1DEElementwiseOperation c1de_element_op,
+const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
+const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+c1_grid_desc_mblock_mperblock_nblock_nperblock,
+const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+const Block2CTileMap block_2_ctile_map,
+const index_t batch_count,
+const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
+const C0MatrixMask c0_matrix_mask)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

View File

@@ -42,27 +42,27 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
-const FloatAB* __restrict__ p_a_grid,
-const FloatAB* __restrict__ p_b_grid,
-const FloatAB* __restrict__ p_b1_grid,
-FloatC* __restrict__ p_c_grid,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const AccElementwiseOperation acc_element_op,
-const B1ElementwiseOperation b1_element_op,
-const CElementwiseOperation c_element_op,
-const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
-const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-c_grid_desc_mblock_mperblock_nblock_nperblock,
-const Block2CTileMap block_2_ctile_map,
-const index_t batch_count,
-const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
-const C0MatrixMask c0_matrix_mask)
+kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
+const FloatAB* __restrict__ p_a_grid,
+const FloatAB* __restrict__ p_b_grid,
+const FloatAB* __restrict__ p_b1_grid,
+FloatC* __restrict__ p_c_grid,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const AccElementwiseOperation acc_element_op,
+const B1ElementwiseOperation b1_element_op,
+const CElementwiseOperation c_element_op,
+const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
+const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+c_grid_desc_mblock_mperblock_nblock_nperblock,
+const Block2CTileMap block_2_ctile_map,
+const index_t batch_count,
+const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
+const C0MatrixMask c0_matrix_mask)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

View File

@@ -29,14 +29,13 @@ template <typename GridwiseGemm,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
-kernel_batched_gemm_wmma_cshuffle_v3(
-typename GridwiseGemm::Argument
-karg, // This works for now but it actually receives a
-// DeviceBatchedGemm_Wmma_CShuffleV3::Argument
-// argument through implicit conversion to base class!
-const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
+kernel_batched_gemm_wmma_cshuffle_v3(
+typename GridwiseGemm::Argument karg, // This works for now but it actually receives a
+// DeviceBatchedGemm_Wmma_CShuffleV3::Argument
+// argument through implicit conversion to base class!
+const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)

View File

@@ -48,9 +48,9 @@ namespace device {
template <typename DeviceOp, typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
+kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
const index_t num_blocks_per_batch =

View File

@@ -33,9 +33,9 @@ template <typename GridwiseGemm,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
-kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg)
+kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -71,9 +71,9 @@ template <typename GridwiseGemm,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
-kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg)
+kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// Pass two lds pointer is the key to tell compiler that ds_read/write

View File

@@ -610,8 +610,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
if(!parg)
{
std::ostringstream err;
err << "Provided argument pointer is not of an Argument class!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
err << "Provided argument pointer is not of an Argument class!" << " In " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}

View File

@@ -467,12 +467,12 @@ struct DeviceColumnToImageImpl
float elapsed_time = 0.f;
const auto kernel = kernel_tensor_rearrange<InputGridDesc,
-InputDataType,
-OutputGridDesc,
-OutputDataType,
-Block2ETileMap,
-ComputePtrOffsetOfStridedBatch<>,
-GridwiseTensorRearrangeKernel>;
+InputDataType,
+OutputGridDesc,
+OutputDataType,
+Block2ETileMap,
+ComputePtrOffsetOfStridedBatch<>,
+GridwiseTensorRearrangeKernel>;
// Execute each set of independent filters
for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++)

View File

@@ -37,23 +37,23 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_contraction_multiple_abd_xdl_cshuffle(
-AsPointer p_as_grid,
-BsPointer p_bs_grid,
-DsPointer p_ds_grid,
-EDataType* __restrict__ p_e_grid,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const CDEElementwiseOperation cde_element_op,
-const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
-const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
-const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-ds_grid_desc_mblock_mperblock_nblock_nperblock,
-const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-e_grid_desc_mblock_mperblock_nblock_nperblock,
-const Block2ETileMap block_2_etile_map)
+kernel_contraction_multiple_abd_xdl_cshuffle(
+AsPointer p_as_grid,
+BsPointer p_bs_grid,
+DsPointer p_ds_grid,
+EDataType* __restrict__ p_e_grid,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CDEElementwiseOperation cde_element_op,
+const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
+const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
+const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+ds_grid_desc_mblock_mperblock_nblock_nperblock,
+const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+e_grid_desc_mblock_mperblock_nblock_nperblock,
+const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

View File

@@ -35,23 +35,23 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_contraction_multiple_d_xdl_cshuffle(
-const FloatAB* __restrict__ p_a_grid,
-const FloatAB* __restrict__ p_b_grid,
-FloatDsPointer p_ds_grid,
-FloatE* __restrict__ p_e_grid,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const CDEElementwiseOperation cde_element_op,
-const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-ds_grid_desc_mblock_mperblock_nblock_nperblock,
-const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-e_grid_desc_mblock_mperblock_nblock_nperblock,
-const Block2ETileMap block_2_etile_map)
+kernel_contraction_multiple_d_xdl_cshuffle(
+const FloatAB* __restrict__ p_a_grid,
+const FloatAB* __restrict__ p_b_grid,
+FloatDsPointer p_ds_grid,
+FloatE* __restrict__ p_e_grid,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CDEElementwiseOperation cde_element_op,
+const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+ds_grid_desc_mblock_mperblock_nblock_nperblock,
+const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+e_grid_desc_mblock_mperblock_nblock_nperblock,
+const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

View File

@@ -35,17 +35,15 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
if(lengths.size() != NumDim1 + NumDim2)
{
std::ostringstream err;
err << "Incorrect number of lengths in "
<< "device_contraction_utils.hpp"
<< ":" << __LINE__ << ", in function: " << __func__;
err << "Incorrect number of lengths in " << "device_contraction_utils.hpp" << ":"
<< __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(strides.size() != NumDim1 + NumDim2)
{
std::ostringstream err;
err << "Incorrect number of strides in "
<< "device_contraction_utils.hpp"
<< ":" << __LINE__ << ", in function: " << __func__;
err << "Incorrect number of strides in " << "device_contraction_utils.hpp" << ":"
<< __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}

View File

@@ -648,9 +648,8 @@ struct
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", "
<< "K " << arg.Conv_K_ << ", "
<< "C " << arg.Conv_C_ << ", " << std::endl;
std::cout << "N " << arg.Conv_N_ << ", " << "K " << arg.Conv_K_ << ", " << "C "
<< arg.Conv_C_ << ", " << std::endl;
std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
<< arg.filter_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "

View File

@@ -618,9 +618,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", "
<< "K " << arg.Conv_K_ << ", "
<< "C " << arg.Conv_C_ << ", " << std::endl;
std::cout << "N " << arg.Conv_N_ << ", " << "K " << arg.Conv_K_ << ", " << "C "
<< arg.Conv_C_ << ", " << std::endl;
std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
<< arg.filter_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "

View File

@@ -583,9 +583,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", "
<< "K " << arg.Conv_K_ << ", "
<< "C " << arg.Conv_C_ << ", " << std::endl;
std::cout << "N " << arg.Conv_N_ << ", " << "K " << arg.Conv_K_ << ", " << "C "
<< arg.Conv_C_ << ", " << std::endl;
std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
<< arg.filter_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "

View File

@@ -38,23 +38,23 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_gemm_xdlops_v2r3_for_conv3d(
-const FloatAB* __restrict__ p_a_grid,
-const FloatAB* __restrict__ p_b_grid,
-FloatC* __restrict__ p_c_grid,
-const index_t num_batches,
-const index_t a_batch_stride,
-const index_t b_batch_stride,
-const index_t c_batch_stride,
-const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
-const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
-const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const CElementwiseOperation c_element_op,
-const Block2CTileMap block_2_ctile_map)
+kernel_gemm_xdlops_v2r3_for_conv3d(
+const FloatAB* __restrict__ p_a_grid,
+const FloatAB* __restrict__ p_b_grid,
+FloatC* __restrict__ p_c_grid,
+const index_t num_batches,
+const index_t a_batch_stride,
+const index_t b_batch_stride,
+const index_t c_batch_stride,
+const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
+const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
+const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CElementwiseOperation c_element_op,
+const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
const index_t num_blocks_per_batch =

View File

@@ -34,21 +34,21 @@ template <typename GridwiseGemm,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_gemm_dl_multiple_d(
-const ABDataType* __restrict__ p_a_grid,
-const ABDataType* __restrict__ p_b_grid,
-DsPointer p_ds_grid,
-EDataType* __restrict__ p_e_grid,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const CDEElementwiseOperation cde_element_op,
-const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
-const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
-const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
-const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
-const Block2CTileMap block_2_ctile_map)
+kernel_gemm_dl_multiple_d(
+const ABDataType* __restrict__ p_a_grid,
+const ABDataType* __restrict__ p_b_grid,
+DsPointer p_ds_grid,
+EDataType* __restrict__ p_e_grid,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CDEElementwiseOperation cde_element_op,
+const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
+const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
+const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
+const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
+const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx9__) || \
defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__))

View File

@@ -37,31 +37,30 @@ template <typename GridwiseGemmWelford,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle(
-const ABDataType* __restrict__ p_a_grid,
-const ABDataType* __restrict__ p_b_grid,
-DsPointer p_ds_grid,
-EMeanVarDataType* __restrict__ p_e_grid,
-EMeanVarDataType* __restrict__ p_welford_mean_grid,
-EMeanVarDataType* __restrict__ p_welford_var_grid,
-int32_t* __restrict__ p_welford_count_grid,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const CDEElementwiseOperation cde_element_op,
-const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-ds_grid_desc_mblock_mperblock_nblock_nperblock,
-const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-e_grid_desc_mblock_mperblock_nblock_nperblock,
-const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
-mean_var_grid_desc_mblock_mperblock_nblock,
-const CountGridDescriptor_MBlock_MPerBlock_NBlock
-count_grid_desc_mblock_mperblock_nblock,
-const Block2ETileMap block_2_etile_map,
-index_t NRaw)
+kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle(
+const ABDataType* __restrict__ p_a_grid,
+const ABDataType* __restrict__ p_b_grid,
+DsPointer p_ds_grid,
+EMeanVarDataType* __restrict__ p_e_grid,
+EMeanVarDataType* __restrict__ p_welford_mean_grid,
+EMeanVarDataType* __restrict__ p_welford_var_grid,
+int32_t* __restrict__ p_welford_count_grid,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CDEElementwiseOperation cde_element_op,
+const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+ds_grid_desc_mblock_mperblock_nblock_nperblock,
+const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+e_grid_desc_mblock_mperblock_nblock_nperblock,
+const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
+mean_var_grid_desc_mblock_mperblock_nblock,
+const CountGridDescriptor_MBlock_MPerBlock_NBlock count_grid_desc_mblock_mperblock_nblock,
+const Block2ETileMap block_2_etile_map,
+index_t NRaw)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
@@ -121,26 +120,26 @@ template <typename GridwiseWelfordLayernorm,
typename HElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_welford_layernorm2d_second_half(
-const EMeanVarDataType* __restrict__ p_e_grid,
-const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
-const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
-const int32_t* __restrict__ p_in_welford_count_grid,
-const GammaDataType* __restrict__ p_gamma_grid,
-const BetaDataType* __restrict__ p_beta_grid,
-HDataType* __restrict__ p_h_grid,
-const EHGridDesc_M_N e_grid_desc_m_n,
-const EHGridDesc_M_N h_grid_desc_m_n,
-const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock,
-const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock,
-const GammaBetaGridDesc_N gamma_grid_desc_n,
-const GammaBetaGridDesc_N beta_grid_desc_n,
-index_t numMeanVarCountBlockTileIteration_N,
-index_t NBlockClusterLength,
-ComputeDataType epsilon,
-HElementwiseOperation h_element_op)
+kernel_welford_layernorm2d_second_half(
+const EMeanVarDataType* __restrict__ p_e_grid,
+const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
+const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
+const int32_t* __restrict__ p_in_welford_count_grid,
+const GammaDataType* __restrict__ p_gamma_grid,
+const BetaDataType* __restrict__ p_beta_grid,
+HDataType* __restrict__ p_h_grid,
+const EHGridDesc_M_N e_grid_desc_m_n,
+const EHGridDesc_M_N h_grid_desc_m_n,
+const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock,
+const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock,
+const GammaBetaGridDesc_N gamma_grid_desc_n,
+const GammaBetaGridDesc_N beta_grid_desc_n,
+index_t numMeanVarCountBlockTileIteration_N,
+index_t NBlockClusterLength,
+ComputeDataType epsilon,
+HElementwiseOperation h_element_op)
{
GridwiseWelfordLayernorm::Run(p_e_grid,
p_in_welford_mean_grid,

View File

@@ -38,27 +38,27 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_gemm_multiple_d_multiple_r_xdl_cshuffle(
-const FloatAB* __restrict__ p_a_grid,
-const FloatAB* __restrict__ p_b_grid,
-FloatDsPointer p_ds_grid,
-FloatE* __restrict__ p_e_grid,
-FloatRsPointer p_rs_grid,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const CDEElementwiseOperation cde_element_op,
-const QsElementwiseOperation qs_element_op,
-const RsElementwiseOperation rs_element_op,
-const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-ds_grid_desc_mblock_mperblock_nblock_nperblock,
-const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-e_grid_desc_mblock_mperblock_nblock_nperblock,
-const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
-const Block2ETileMap block_2_etile_map)
+kernel_gemm_multiple_d_multiple_r_xdl_cshuffle(
+const FloatAB* __restrict__ p_a_grid,
+const FloatAB* __restrict__ p_b_grid,
+FloatDsPointer p_ds_grid,
+FloatE* __restrict__ p_e_grid,
+FloatRsPointer p_rs_grid,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CDEElementwiseOperation cde_element_op,
+const QsElementwiseOperation qs_element_op,
+const RsElementwiseOperation rs_element_op,
+const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+ds_grid_desc_mblock_mperblock_nblock_nperblock,
+const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+e_grid_desc_mblock_mperblock_nblock_nperblock,
+const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
+const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

View File

@@ -37,22 +37,22 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
-const BDataType* __restrict__ p_b_grid,
-DsPointer p_ds_grid,
-EDataType* __restrict__ p_e_grid,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const CDEElementwiseOperation cde_element_op,
-const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-ds_grid_desc_mblock_mperblock_nblock_nperblock,
-const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-e_grid_desc_mblock_mperblock_nblock_nperblock,
-const Block2ETileMap block_2_etile_map)
+kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
+const BDataType* __restrict__ p_b_grid,
+DsPointer p_ds_grid,
+EDataType* __restrict__ p_e_grid,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CDEElementwiseOperation cde_element_op,
+const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+ds_grid_desc_mblock_mperblock_nblock_nperblock,
+const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+e_grid_desc_mblock_mperblock_nblock_nperblock,
+const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

View File

@@ -32,20 +32,19 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_WAVELET_MAX_THREAD_PER_BLOCK, CK_WAVELET_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_WAVELET_MAX_THREAD_PER_BLOCK, CK_WAVELET_MIN_BLOCK_PER_CU)
#endif
-kernel_gemm_xdl_waveletmodel_cshuffle(
-const ABDataType* __restrict__ p_a_grid,
-const ABDataType* __restrict__ p_b_grid,
-EDataType* __restrict__ p_e_grid,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const EElementwiseOperation e_element_op,
-const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-e_grid_desc_mblock_mperblock_nblock_nperblock,
-const Block2ETileMap block_2_etile_map)
+kernel_gemm_xdl_waveletmodel_cshuffle(const ABDataType* __restrict__ p_a_grid,
+const ABDataType* __restrict__ p_b_grid,
+EDataType* __restrict__ p_e_grid,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const EElementwiseOperation e_element_op,
+const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+e_grid_desc_mblock_mperblock_nblock_nperblock,
+const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

View File

@@ -28,14 +28,14 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_grouped_contraction_multiple_d_xdl_cshuffle(
-const void CK_CONSTANT_ADDRESS_SPACE* contraction_args,
-const index_t group_count,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const CDEElementwiseOperation cde_element_op)
+kernel_grouped_contraction_multiple_d_xdl_cshuffle(
+const void CK_CONSTANT_ADDRESS_SPACE* contraction_args,
+const index_t group_count,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

View File

@@ -80,21 +80,21 @@ template <typename GridwiseGemm,
bool CTranspose>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle(
-const ABDataType* __restrict__ p_a_grid,
-const ABDataType* __restrict__ p_b_grid,
-DsPointer p_ds_grid,
-EDataType* __restrict__ p_e_grid,
-const std::array<GemmArgs, MaxGroupedGemmGroupsNum> gemm_kernel_args,
-const index_t gemms_count,
-const AElementwiseOp a_element_op,
-const BElementwiseOp b_element_op,
-const CDEElementwiseOp cde_element_op,
-const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-const ComputePtrOffsetOfN compute_ptr_offset_of_n,
-const index_t KBatch)
+kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle(
+const ABDataType* __restrict__ p_a_grid,
+const ABDataType* __restrict__ p_b_grid,
+DsPointer p_ds_grid,
+EDataType* __restrict__ p_e_grid,
+const std::array<GemmArgs, MaxGroupedGemmGroupsNum> gemm_kernel_args,
+const index_t gemms_count,
+const AElementwiseOp a_element_op,
+const BElementwiseOp b_element_op,
+const CDEElementwiseOp cde_element_op,
+const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+const ComputePtrOffsetOfN compute_ptr_offset_of_n,
+const index_t KBatch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group

View File

@@ -35,18 +35,18 @@ template <typename GridwiseGemm,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_batched_gemm_dlops_bwd_weight(
-const FloatAB* __restrict__ p_a_grid,
-const FloatAB* __restrict__ p_b_grid,
-FloatC* __restrict__ p_c_grid,
-const index_t batch_count,
-const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1,
-const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1,
-const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
-const Block2CTileMap block_2_ctile_map,
-const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
+kernel_batched_gemm_dlops_bwd_weight(
+const FloatAB* __restrict__ p_a_grid,
+const FloatAB* __restrict__ p_b_grid,
+FloatC* __restrict__ p_c_grid,
+const index_t batch_count,
+const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1,
+const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1,
+const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
+const Block2CTileMap block_2_ctile_map,
+const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \

View File

@@ -77,21 +77,21 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
using CElementwiseGridDesc = remove_cvref_t<decltype(GetElementwiseCGridDesc(I1))>;
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<1, ElemsPerBlock>;
using GridwiseElementwiseCast = GridwiseElementwise<Tuple<CElementwiseGridDesc>,
-Tuple<CElementwiseGridDesc>,
-Tuple<const float*>,
-Tuple<WeiDataType*>,
-Block2TileMapElementwise,
-WeiElementwiseOperation,
-ElementwiseBlockSize,
-I1,
-ElemsPerBlock,
-I1,
-ElemsPerBlock / ElementwiseBlockSize,
-Sequence<0, 1>,
-Sequence<1>,
-Sequence<1>,
-I1,
-I1>;
+Tuple<CElementwiseGridDesc>,
+Tuple<const float*>,
+Tuple<WeiDataType*>,
+Block2TileMapElementwise,
+WeiElementwiseOperation,
+ElementwiseBlockSize,
+I1,
+ElemsPerBlock,
+I1,
+ElemsPerBlock / ElementwiseBlockSize,
+Sequence<0, 1>,
+Sequence<1>,
+Sequence<1>,
+I1,
+I1>;
struct Argument : public BaseArgument
{

View File

@@ -43,22 +43,21 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_batched_gemm_xdlops_bwd_weight(
-const FloatA* __restrict__ p_a_grid,
-const FloatB* __restrict__ p_b_grid,
-FloatC* __restrict__ p_c_grid,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const CElementwiseOperation c_element_op,
-const index_t batch_count,
-const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
-const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
-const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-c_grid_desc_mblock_mperblock_nblock_nperblock,
-const Block2CTileMap block_2_ctile_map,
-const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
+kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid,
+const FloatB* __restrict__ p_b_grid,
+FloatC* __restrict__ p_c_grid,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CElementwiseOperation c_element_op,
+const index_t batch_count,
+const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
+const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
+const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+c_grid_desc_mblock_mperblock_nblock_nperblock,
+const Block2CTileMap block_2_ctile_map,
+const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
const index_t num_blocks_per_batch =

View File

@@ -44,16 +44,16 @@ template <typename GridwiseGemm,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
-kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3(
-typename GridwiseGemm::Argument karg,
-[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
-[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
-[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-c_grid_desc_mblock_mperblock_nblock_nperblock,
-[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-[[maybe_unused]] const index_t num_k_per_block)
+kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3(
+typename GridwiseGemm::Argument karg,
+[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
+[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
+[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+c_grid_desc_mblock_mperblock_nblock_nperblock,
+[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+[[maybe_unused]] const index_t num_k_per_block)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
@@ -99,16 +99,16 @@ template <typename GridwiseGemm,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
-kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds(
-typename GridwiseGemm::Argument karg,
-[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
-[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
-[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-c_grid_desc_mblock_mperblock_nblock_nperblock,
-[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-[[maybe_unused]] const index_t num_k_per_block)
+kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds(
+typename GridwiseGemm::Argument karg,
+[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
+[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
+[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+c_grid_desc_mblock_mperblock_nblock_nperblock,
+[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+[[maybe_unused]] const index_t num_k_per_block)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group

View File

@@ -41,22 +41,21 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_batched_gemm_xdlops_bwd_weight(
-const FloatA* __restrict__ p_a_grid,
-const FloatB* __restrict__ p_b_grid,
-FloatC* __restrict__ p_c_grid,
-const AElementwiseOperation a_element_op,
-const BElementwiseOperation b_element_op,
-const CElementwiseOperation c_element_op,
-const index_t batch_count,
-const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
-const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
-const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-c_grid_desc_mblock_mperblock_nblock_nperblock,
-const Block2CTileMap block_2_ctile_map,
-const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
+kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid,
+const FloatB* __restrict__ p_b_grid,
+FloatC* __restrict__ p_c_grid,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CElementwiseOperation c_element_op,
+const index_t batch_count,
+const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
+const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
+const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+c_grid_desc_mblock_mperblock_nblock_nperblock,
+const Block2CTileMap block_2_ctile_map,
+const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))

@@ -42,16 +42,16 @@ template <typename GridwiseGemm,
           TailNumber TailNum = TailNumber::Full>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-    kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3(
-        typename GridwiseGemm::Argument karg,
-        const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
-        const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
-        const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-            c_grid_desc_mblock_mperblock_nblock_nperblock,
-        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-        const index_t num_k_per_block)
+    kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3(
+        typename GridwiseGemm::Argument karg,
+        const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
+        const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+        const index_t num_k_per_block)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
@@ -100,16 +100,16 @@ template <typename GridwiseGemm,
           TailNumber TailNum = TailNumber::Full>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-    kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds(
-        typename GridwiseGemm::Argument karg,
-        const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
-        const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
-        const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-            c_grid_desc_mblock_mperblock_nblock_nperblock,
-        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-        const index_t num_k_per_block)
+    kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds(
+        typename GridwiseGemm::Argument karg,
+        const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
+        const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+        const index_t num_k_per_block)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // offset base pointer for each work-group

@@ -72,23 +72,23 @@ template <typename GridwiseGemm,
           bool HasDoubleTailKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_conv_fwd_dl_multiple_d(
-        const ABDataType* __restrict__ p_a_grid,
-        const ABDataType* __restrict__ p_b_grid,
-        DsPointer p_ds_grid,
-        EDataType* __restrict__ p_e_grid,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
-        const CDEElementwiseOperation cde_element_op,
-        const index_t batch_count,
-        const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
-        const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
-        const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
-        const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
-        const Block2CTileMap block_2_ctile_map,
-        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
+    kernel_grouped_conv_fwd_dl_multiple_d(
+        const ABDataType* __restrict__ p_a_grid,
+        const ABDataType* __restrict__ p_b_grid,
+        DsPointer p_ds_grid,
+        EDataType* __restrict__ p_e_grid,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CDEElementwiseOperation cde_element_op,
+        const index_t batch_count,
+        const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
+        const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
+        const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
+        const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
+        const Block2CTileMap block_2_ctile_map,
+        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
     defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \

@@ -93,18 +93,18 @@ template <typename GridwiseGemm,
           bool HasDoubleTailKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_conv_fwd_dl(
-        const ABDataType* __restrict__ p_a_grid,
-        const ABDataType* __restrict__ p_b_grid,
-        CDataType* __restrict__ p_c_grid,
-        const index_t batch_count,
-        const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
-        const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
-        const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
-        const Block2CTileMap block_2_ctile_map,
-        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
+    kernel_grouped_conv_fwd_dl(
+        const ABDataType* __restrict__ p_a_grid,
+        const ABDataType* __restrict__ p_b_grid,
+        CDataType* __restrict__ p_c_grid,
+        const index_t batch_count,
+        const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
+        const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
+        const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
+        const Block2CTileMap block_2_ctile_map,
+        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
     defined(__gfx11__) || defined(__gfx12__))

@@ -81,25 +81,25 @@ template <typename GridwiseGemm,
           bool CTranspose>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
-        AsPointer p_as_grid,
-        BsPointer p_bs_grid,
-        DsPointer p_ds_grid,
-        EDataType* __restrict__ p_e_grid,
-        AElementwiseOperation a_element_op,
-        BElementwiseOperation b_element_op,
-        CDEElementwiseOperation cde_element_op,
-        const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
-        const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
-        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            ds_grid_desc_mblock_mperblock_nblock_nperblock,
-        const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-            e_grid_desc_mblock_mperblock_nblock_nperblock_,
-        const Block2ETileMap block_2_ctile_map,
-        const ComputePtrOffsetOfG compute_ptr_offset_of_groups,
-        const ComputePtrOffsetOfN compute_ptr_offset_of_n)
+    kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
+        AsPointer p_as_grid,
+        BsPointer p_bs_grid,
+        DsPointer p_ds_grid,
+        EDataType* __restrict__ p_e_grid,
+        AElementwiseOperation a_element_op,
+        BElementwiseOperation b_element_op,
+        CDEElementwiseOperation cde_element_op,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
+        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+        const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            e_grid_desc_mblock_mperblock_nblock_nperblock_,
+        const Block2ETileMap block_2_ctile_map,
+        const ComputePtrOffsetOfG compute_ptr_offset_of_groups,
+        const ComputePtrOffsetOfN compute_ptr_offset_of_n)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
@@ -383,11 +383,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     {
         namespace ctc = tensor_layout::convolution;
         using Layout = std::conditional_t<
-            is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
-            ctc::NHWGC,
-            std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
-                               ctc::NDHWGC,
-                               ALay>>;
+            is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
+            ctc::NHWGC,
+            std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
+                               ctc::NDHWGC,
+                               ALay>>;
         const auto in_gemmmraw_gemmkraw_desc =
             conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
@@ -403,11 +403,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     {
         namespace ctc = tensor_layout::convolution;
        using Layout = std::conditional_t<
-            is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
-            ctc::GKYXC,
-            std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
-                               ctc::GKZYXC,
-                               BLay>>;
+            is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
+            ctc::GKYXC,
+            std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
+                               ctc::GKZYXC,
+                               BLay>>;
         const auto wei_gemmnraw_gemmkraw_desc =
             conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
@@ -423,11 +423,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     {
         namespace ctc = tensor_layout::convolution;
         using Layout = std::conditional_t<
-            is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
-            ctc::NHWGK,
-            std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
-                               ctc::NDHWGK,
-                               ELay>>;
+            is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
+            ctc::NHWGK,
+            std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
+                               ctc::NDHWGK,
+                               ELay>>;
         const auto out_gemmmraw_gemmnraw_desc =
             conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();

@@ -72,15 +72,15 @@ template <typename GridwiseGemm,
           TailNumber TailNum = TailNumber::Full>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-    kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg,
-                                            const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
-                                            const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
-                                            const DsGridDesc_M_N ds_grid_desc_m_n,
-                                            const EGridDesc_M_N c_grid_desc_m_n,
-                                            const ComputePtrOffset compute_ptr_offset_of_groups,
-                                            const ComputePtrOffset compute_ptr_offset_of_n)
+    kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg,
+                                            const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
+                                            const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
+                                            const DsGridDesc_M_N ds_grid_desc_m_n,
+                                            const EGridDesc_M_N c_grid_desc_m_n,
+                                            const ComputePtrOffset compute_ptr_offset_of_groups,
+                                            const ComputePtrOffset compute_ptr_offset_of_n)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // offset base pointer for each work-group
@@ -151,16 +151,16 @@ template <typename GridwiseGemm,
           TailNumber TailNum = TailNumber::Full>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-    kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds(
-        typename GridwiseGemm::Argument karg,
-        const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
-        const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
-        const DsGridDesc_M_N ds_grid_desc_m_n,
-        const EGridDesc_M_N c_grid_desc_m_n,
-        const ComputePtrOffset compute_ptr_offset_of_groups,
-        const ComputePtrOffset compute_ptr_offset_of_n)
+    kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds(
+        typename GridwiseGemm::Argument karg,
+        const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
+        const DsGridDesc_M_N ds_grid_desc_m_n,
+        const EGridDesc_M_N c_grid_desc_m_n,
+        const ComputePtrOffset compute_ptr_offset_of_groups,
+        const ComputePtrOffset compute_ptr_offset_of_n)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // offset base pointer for each work-group
@@ -369,11 +369,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
     {
         namespace ctc = tensor_layout::convolution;
         using Layout = std::conditional_t<
-            is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>(),
-            ctc::NHWGC,
-            std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
-                               ctc::NDHWGC,
-                               ALay>>;
+            is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>(),
+            ctc::NHWGC,
+            std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
+                               ctc::NDHWGC,
+                               ALay>>;
         const auto in_gemmmraw_gemmkraw_desc =
             conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
@@ -399,11 +399,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
     {
         namespace ctc = tensor_layout::convolution;
         using Layout = std::conditional_t<
-            is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>(),
-            ctc::GKYXC,
-            std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
-                               ctc::GKZYXC,
-                               BLay>>;
+            is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>(),
+            ctc::GKYXC,
+            std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
+                               ctc::GKZYXC,
+                               BLay>>;
         const auto wei_gemmnraw_gemmkraw_desc =
             conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
@@ -429,11 +429,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
     {
         namespace ctc = tensor_layout::convolution;
         using Layout = std::conditional_t<
-            is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>(),
-            ctc::NHWGK,
-            std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
-                               ctc::NDHWGK,
-                               ELay>>;
+            is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>(),
+            ctc::NHWGK,
+            std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
+                               ctc::NDHWGK,
+                               ELay>>;
         const auto out_gemmmraw_gemmnraw_desc =
             conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
@@ -1347,9 +1347,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
                 return false;
             if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
             {
-                std::cout << "The MultiABD is not supported!"
-                          << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
-                          << std::endl;
+                std::cout << "The MultiABD is not supported!" << " In " << __FILE__ << ":"
+                          << __LINE__ << ", in function: " << __func__ << std::endl;
             }
         }
@@ -1374,8 +1373,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
         {
             if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
             {
-                std::cout << "Current device does not support xdl instructions!"
-                          << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                std::cout << "Current device does not support xdl instructions!" << " In "
+                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                           << std::endl;
             }
             return false;
@@ -1455,9 +1454,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
         {
             if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
             {
-                std::cout << "Unsupported A Layout!"
-                          << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
-                          << std::endl;
+                std::cout << "Unsupported A Layout!" << " In " << __FILE__ << ":" << __LINE__
+                          << ", in function: " << __func__ << std::endl;
             }
             return false;
         }
@@ -1488,9 +1486,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
         {
             if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
             {
-                std::cout << "Unsupported A Layout!"
-                          << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
-                          << std::endl;
+                std::cout << "Unsupported A Layout!" << " In " << __FILE__ << ":" << __LINE__
+                          << ", in function: " << __func__ << std::endl;
             }
             return false;
         }
@@ -1602,9 +1599,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
         {
             if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
             {
-                std::cout << "Unsupported E Layout!"
-                          << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
-                          << std::endl;
+                std::cout << "Unsupported E Layout!" << " In " << __FILE__ << ":" << __LINE__
+                          << ", in function: " << __func__ << std::endl;
             }
             return false;
         }

@@ -131,29 +131,29 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_batch_gemm_multiple_d_xdl_cshuffle(
-        const ABDataType* __restrict__ p_a_grid,
-        const ABDataType* __restrict__ p_b_grid,
-        DsPointer p_ds_grid,
-        EDataType* __restrict__ p_e_grid,
-        RsPointer p_rs_grid,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
-        const CDEElementwiseOperation cde_element_op,
-        const QsElementwiseOperation qs_element_op,
-        const RsElementwiseOperation rs_element_op,
-        const index_t batch_count,
-        const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
-        const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
-        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            ds_grid_desc_mblock_mperblock_nblock_nperblock,
-        const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-            e_grid_desc_mblock_mperblock_nblock_nperblock_,
-        const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
-        const Block2ETileMap block_2_ctile_map,
-        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
+    kernel_batch_gemm_multiple_d_xdl_cshuffle(
+        const ABDataType* __restrict__ p_a_grid,
+        const ABDataType* __restrict__ p_b_grid,
+        DsPointer p_ds_grid,
+        EDataType* __restrict__ p_e_grid,
+        RsPointer p_rs_grid,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CDEElementwiseOperation cde_element_op,
+        const QsElementwiseOperation qs_element_op,
+        const RsElementwiseOperation rs_element_op,
+        const index_t batch_count,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
+        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+        const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            e_grid_desc_mblock_mperblock_nblock_nperblock_,
+        const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
+        const Block2ETileMap block_2_ctile_map,
+        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     const index_t num_blocks_per_batch =

@@ -41,16 +41,16 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle(
-        Array<GemmArgs, MaxGemmsNum> gemm_desc_kernel_args,
-        const index_t gemms_count,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
-        const CDEElementwiseOperation c_element_op,
-        const ComputePtrOffset compute_ptr_offset_of_groups,
-        const ComputePtrOffset compute_ptr_offset_of_n)
+    kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle(
+        Array<GemmArgs, MaxGemmsNum> gemm_desc_kernel_args,
+        const index_t gemms_count,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CDEElementwiseOperation c_element_op,
+        const ComputePtrOffset compute_ptr_offset_of_groups,
+        const ComputePtrOffset compute_ptr_offset_of_n)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

@@ -36,14 +36,14 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                     const index_t group_count,
-                                     const index_t grid_size_grp,
-                                     const AElementwiseOperation a_element_op,
-                                     const BElementwiseOperation b_element_op,
-                                     const CDEElementwiseOperation cde_element_op)
+    kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
+                                     const index_t group_count,
+                                     const index_t grid_size_grp,
+                                     const AElementwiseOperation a_element_op,
+                                     const BElementwiseOperation b_element_op,
+                                     const CDEElementwiseOperation cde_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

@@ -32,13 +32,13 @@ template <typename GridwiseGemm,
           bool HasDoubleTailKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_gemm_multiple_d_dl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                      const index_t group_count,
-                                      const AElementwiseOperation a_element_op,
-                                      const BElementwiseOperation b_element_op,
-                                      const CDEElementwiseOperation cde_element_op)
+    kernel_grouped_gemm_multiple_d_dl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
+                                      const index_t group_count,
+                                      const AElementwiseOperation a_element_op,
+                                      const BElementwiseOperation b_element_op,
+                                      const CDEElementwiseOperation cde_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
     defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \

@@ -576,16 +576,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
             if(dev_gemm_args == nullptr)
             {
                 std::ostringstream err;
-                err << "The gemm arguments device buffer is not allocated!"
-                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
+                err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__
+                    << ":" << __LINE__ << ", in function: " << __func__;
                 throw std::runtime_error(err.str());
             }
             if(dev_gemm_workspace == nullptr)
             {
                 std::ostringstream err;
-                err << "The gemm workspace buffer is not allocated!"
-                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
+                err << "The gemm workspace buffer is not allocated!" << " In " << __FILE__ << ":"
+                    << __LINE__ << ", in function: " << __func__;
                 throw std::runtime_error(err.str());
             }
@@ -624,16 +624,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
             if(arg.p_dev_gemm_kargs_ == nullptr)
             {
                 std::ostringstream err;
-                err << "The gemm arguments device buffer is not allocated!"
-                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
+                err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__
+                    << ":" << __LINE__ << ", in function: " << __func__;
                 throw std::runtime_error(err.str());
             }
             if(arg.p_workspace_ == nullptr)
             {
                 std::ostringstream err;
-                err << "The gemm workspace buffer is not allocated!"
-                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
+                err << "The gemm workspace buffer is not allocated!" << " In " << __FILE__ << ":"
+                    << __LINE__ << ", in function: " << __func__;
                 throw std::runtime_error(err.str());
             }
@@ -711,8 +711,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
             if(not_all_have_kbatch_value_same)
             {
                 std::ostringstream err;
-                err << "Not all gemms have same kbatch value (=1 or >1)! "
-                    << "group [" << i << "], kbatch: " << gemm_arg.k_batch
+                err << "Not all gemms have same kbatch value (=1 or >1)! " << "group [" << i
+                    << "], kbatch: " << gemm_arg.k_batch
                     << ", group [0], kbatch: " << gemm_arg.k_batch << " in " << __FILE__ << ":"
                     << __LINE__ << ", in function: " << __func__;
                 throw std::runtime_error(err.str());

@@ -60,13 +60,13 @@ template <typename GridwiseGemm,
           BlockGemmPipelineVersion BlkGemmPipelineVer>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                       const index_t group_count,
-                                       const AElementwiseOperation a_element_op,
-                                       const BElementwiseOperation b_element_op,
-                                       const CDEElementwiseOperation cde_element_op)
+    kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
+                                       const index_t group_count,
+                                       const AElementwiseOperation a_element_op,
+                                       const BElementwiseOperation b_element_op,
+                                       const CDEElementwiseOperation cde_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
@@ -600,8 +600,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
         if(dev_gemm_args == nullptr)
         {
             std::ostringstream err;
-            err << "The gemm arguments device buffer is not allocated!"
-                << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
+            err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__
+                << ":" << __LINE__ << ", in function: " << __func__;
             throw std::runtime_error(err.str());
         }
@@ -629,8 +629,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
         if(arg.p_dev_gemm_args_ == nullptr)
        {
             std::ostringstream err;
-            err << "The gemm arguments device buffer is not allocated!"
-                << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
+            err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__
+                << ":" << __LINE__ << ", in function: " << __func__;
             throw std::runtime_error(err.str());
         }

@@ -32,16 +32,16 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1(
-        const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
-        const index_t group_count,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
-        const AccElementwiseOperation acc_element_op,
-        const B1ElementwiseOperation b1_element_op,
-        const CElementwiseOperation c_element_op)
+    kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1(
+        const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
+        const index_t group_count,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const AccElementwiseOperation acc_element_op,
+        const B1ElementwiseOperation b1_element_op,
+        const CElementwiseOperation c_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

@@ -31,13 +31,13 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                            const index_t group_count,
-                            const AElementwiseOperation a_element_op,
-                            const BElementwiseOperation b_element_op,
-                            const CDEElementwiseOperation c_element_op)
+    kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
+                            const index_t group_count,
+                            const AElementwiseOperation a_element_op,
+                            const BElementwiseOperation b_element_op,
+                            const CDEElementwiseOperation c_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

@@ -38,17 +38,17 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                     uint32_t* barrier_count,
-                                     const index_t barrier_size_grp,
-                                     const index_t group_count,
-                                     const index_t grid_size_grp,
-                                     const index_t KBatch,
-                                     const AElementwiseOperation a_element_op,
-                                     const BElementwiseOperation b_element_op,
-                                     const CDEElementwiseOperation c_element_op)
+    kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
+                                     uint32_t* barrier_count,
+                                     const index_t barrier_size_grp,
+                                     const index_t group_count,
+                                     const index_t grid_size_grp,
+                                     const index_t KBatch,
+                                     const AElementwiseOperation a_element_op,
+                                     const BElementwiseOperation b_element_op,
+                                     const CDEElementwiseOperation c_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

@@ -33,13 +33,13 @@ template <typename GridwiseGemm,
           typename CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                   const index_t group_count,
-                                   const AElementwiseOperation a_element_op,
-                                   const BElementwiseOperation b_element_op,
-                                   const CElementwiseOperation c_element_op)
+    kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
+                                   const index_t group_count,
+                                   const AElementwiseOperation a_element_op,
+                                   const BElementwiseOperation b_element_op,
+                                   const CElementwiseOperation c_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
@@ -416,8 +416,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             if(not_all_have_kbatch_value_same)
             {
                 std::ostringstream err;
-                err << "Not all gemms have same kbatch value (=1 or >1)! "
-                    << "group [" << i << "], kbatch: " << kbatch
+                err << "Not all gemms have same kbatch value (=1 or >1)! " << "group [" << i
+                    << "], kbatch: " << kbatch
                     << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch
                     << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                 throw std::runtime_error(err.str());

@@ -45,21 +45,21 @@ template <typename DeviceOp,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_query_attention_wmma(const ADataType* __restrict__ p_a_grid,
-                                        const B0DataType* __restrict__ p_b0_grid,
-                                        const B1DataType* __restrict__ p_b1_grid,
-                                        CDataType* __restrict__ p_c_grid,
-                                        index_t M,  // SequenceQ
-                                        index_t N,  // SequenceK
-                                        index_t K,  // HeadDim
-                                        index_t O,  // SequenceK
-                                        index_t G0, // Batch
-                                        index_t G1, // HeadNum
-                                        float alpha,
-                                        bool input_permute,
-                                        bool output_permute)
+    kernel_grouped_query_attention_wmma(const ADataType* __restrict__ p_a_grid,
+                                        const B0DataType* __restrict__ p_b0_grid,
+                                        const B1DataType* __restrict__ p_b1_grid,
+                                        CDataType* __restrict__ p_c_grid,
+                                        index_t M,  // SequenceQ
+                                        index_t N,  // SequenceK
+                                        index_t K,  // HeadDim
+                                        index_t O,  // SequenceK
+                                        index_t G0, // Batch
+                                        index_t G1, // HeadNum
+                                        float alpha,
+                                        bool input_permute,
+                                        bool output_permute)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))

@@ -100,64 +100,64 @@ struct DeviceMoeGemmBlockScale
 {
     static constexpr index_t NumDTensor = DsDataType::Size();
     using GridwiseGemm = GridwiseMoeGemmBlockScale<
-        ALayout,
-        BLayout,
-        DsLayout,
-        CLayout,
-        ADataType,
-        BDataType,
-        GemmAccDataType,
-        CShuffleDataType,
-        DsDataType,
-        CDataType,
-        AElementwiseOperation,
-        BElementwiseOperation,
-        CElementwiseOperation,
-        GemmSpec,
-        BlockSize,
-        ScaleBlockM,
-        ScaleBlockN,
-        ScaleBlockK,
-        MPerBlock,
-        NPerBlock,
-        KPerBlock,
-        AK1,
-        BK1,
-        MPerXDL,
-        NPerXDL,
-        MXdlPerWave,
-        NXdlPerWave,
-        ABlockTransferThreadClusterLengths_AK0_M_AK1,
-        ABlockTransferThreadClusterArrangeOrder,
-        ABlockTransferSrcAccessOrder,
-        ABlockTransferSrcVectorDim,
-        ABlockTransferSrcScalarPerVector,
-        ABlockTransferDstScalarPerVector_AK1,
-        false,
-        ABlockLdsExtraM,
-        BBlockTransferThreadClusterLengths_BK0_N_BK1,
-        BBlockTransferThreadClusterArrangeOrder,
-        BBlockTransferSrcAccessOrder,
-        BBlockTransferSrcVectorDim,
-        BBlockTransferSrcScalarPerVector,
-        BBlockTransferDstScalarPerVector_BK1,
-        false,
-        BBlockLdsExtraN,
-        CShuffleMXdlPerWavePerShuffle,
-        CShuffleNXdlPerWavePerShuffle,
-        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-        CDEShuffleBlockTransferScalarPerVectors,
-        BlkGemmPipeSched,
-        BlkGemmPipelineVer,
-        ActivationOP,
-        NSwizzle,
-        IsInputGemm,
-        MulRoutedWeight,
-        IndexType,
-        ComputeTypeA,
-        ComputeTypeB,
-        LDSTypeA,
-        LDSTypeB>;
+        ALayout,
+        BLayout,
+        DsLayout,
+        CLayout,
+        ADataType,
+        BDataType,
+        GemmAccDataType,
+        CShuffleDataType,
+        DsDataType,
+        CDataType,
+        AElementwiseOperation,
+        BElementwiseOperation,
+        CElementwiseOperation,
+        GemmSpec,
+        BlockSize,
+        ScaleBlockM,
+        ScaleBlockN,
+        ScaleBlockK,
+        MPerBlock,
+        NPerBlock,
+        KPerBlock,
+        AK1,
+        BK1,
+        MPerXDL,
+        NPerXDL,
+        MXdlPerWave,
+        NXdlPerWave,
+        ABlockTransferThreadClusterLengths_AK0_M_AK1,
+        ABlockTransferThreadClusterArrangeOrder,
+        ABlockTransferSrcAccessOrder,
+        ABlockTransferSrcVectorDim,
+        ABlockTransferSrcScalarPerVector,
+        ABlockTransferDstScalarPerVector_AK1,
+        false,
+        ABlockLdsExtraM,
+        BBlockTransferThreadClusterLengths_BK0_N_BK1,
+        BBlockTransferThreadClusterArrangeOrder,
+        BBlockTransferSrcAccessOrder,
+        BBlockTransferSrcVectorDim,
+        BBlockTransferSrcScalarPerVector,
+        BBlockTransferDstScalarPerVector_BK1,
+        false,
+        BBlockLdsExtraN,
+        CShuffleMXdlPerWavePerShuffle,
+        CShuffleNXdlPerWavePerShuffle,
+        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+        CDEShuffleBlockTransferScalarPerVectors,
+        BlkGemmPipeSched,
+        BlkGemmPipelineVer,
+        ActivationOP,
+        NSwizzle,
+        IsInputGemm,
+        MulRoutedWeight,
+        IndexType,
+        ComputeTypeA,
+        ComputeTypeB,
+        LDSTypeA,
+        LDSTypeB>;
     using Argument = typename GridwiseGemm::Argument;

@@ -92,62 +92,62 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle<ALayout,
 {
     static constexpr index_t NumDTensor = DsDataType::Size();
     using GridwiseGemm = GridwiseMoeGemmMX_BPreshuffle<
-        ALayout,
-        BLayout,
-        DsLayout,
-        CLayout,
-        ADataType,
-        AScaleDataType,
-        BDataType,
-        BScaleDataType,
-        GemmAccDataType,
-        CShuffleDataType,
-        DsDataType,
-        CDataType,
-        AElementwiseOperation,
-        BElementwiseOperation,
-        CElementwiseOperation,
-        GemmSpec,
-        ScaleBlockSize,
-        BlockSize,
-        MPerBlock,
-        NPerBlock,
-        KPerBlock,
-        AK1,
-        BK1,
-        MPerXDL,
-        NPerXDL,
-        MXdlPerWave,
-        NXdlPerWave,
-        ABlockTransferThreadClusterLengths_AK0_M_AK1,
-        ABlockTransferThreadClusterArrangeOrder,
-        ABlockTransferSrcAccessOrder,
-        ABlockTransferSrcVectorDim,
-        ABlockTransferSrcScalarPerVector,
-        ABlockTransferDstScalarPerVector_AK1,
-        false,
-        ABlockLdsExtraM,
-        BBlockTransferThreadClusterLengths_BK0_N_BK1,
-        BBlockTransferThreadClusterArrangeOrder,
-        BBlockTransferSrcAccessOrder,
-        BBlockTransferSrcVectorDim,
-        BBlockTransferSrcScalarPerVector,
-        BBlockTransferDstScalarPerVector_BK1,
-        false,
-        BBlockLdsExtraN,
-        CShuffleMXdlPerWavePerShuffle,
-        CShuffleNXdlPerWavePerShuffle,
-        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-        CDEShuffleBlockTransferScalarPerVectors,
-        BlkGemmPipeSched,
-        BlkGemmPipelineVer,
-        ActivationOP,
-        NSwizzle,
-        IsInputGemm,
-        MulRoutedWeight,
-        IndexType,
-        ComputeTypeA,
-        ComputeTypeB>;
+        ALayout,
+        BLayout,
+        DsLayout,
+        CLayout,
+        ADataType,
+        AScaleDataType,
+        BDataType,
+        BScaleDataType,
+        GemmAccDataType,
+        CShuffleDataType,
+        DsDataType,
+        CDataType,
+        AElementwiseOperation,
+        BElementwiseOperation,
+        CElementwiseOperation,
+        GemmSpec,
+        ScaleBlockSize,
+        BlockSize,
+        MPerBlock,
+        NPerBlock,
+        KPerBlock,
+        AK1,
+        BK1,
+        MPerXDL,
+        NPerXDL,
+        MXdlPerWave,
+        NXdlPerWave,
+        ABlockTransferThreadClusterLengths_AK0_M_AK1,
+        ABlockTransferThreadClusterArrangeOrder,
+        ABlockTransferSrcAccessOrder,
+        ABlockTransferSrcVectorDim,
+        ABlockTransferSrcScalarPerVector,
+        ABlockTransferDstScalarPerVector_AK1,
+        false,
+        ABlockLdsExtraM,
+        BBlockTransferThreadClusterLengths_BK0_N_BK1,
+        BBlockTransferThreadClusterArrangeOrder,
+        BBlockTransferSrcAccessOrder,
+        BBlockTransferSrcVectorDim,
+        BBlockTransferSrcScalarPerVector,
+        BBlockTransferDstScalarPerVector_BK1,
+        false,
+        BBlockLdsExtraN,
+        CShuffleMXdlPerWavePerShuffle,
+        CShuffleNXdlPerWavePerShuffle,
+        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+        CDEShuffleBlockTransferScalarPerVectors,
+        BlkGemmPipeSched,
+        BlkGemmPipelineVer,
+        ActivationOP,
+        NSwizzle,
+        IsInputGemm,
+        MulRoutedWeight,
+        IndexType,
+        ComputeTypeA,
+        ComputeTypeB>;
     using Argument = typename GridwiseGemm::Argument;

@@ -44,21 +44,21 @@ template <typename DeviceOp,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_multi_query_attention_wmma(const ADataType* __restrict__ p_a_grid,
-                                      const B0DataType* __restrict__ p_b0_grid,
-                                      const B1DataType* __restrict__ p_b1_grid,
-                                      CDataType* __restrict__ p_c_grid,
-                                      index_t M,  // SequenceQ
-                                      index_t N,  // SequenceK
-                                      index_t K,  // HeadDim
-                                      index_t O,  // SequenceK
-                                      index_t G0, // Batch
-                                      index_t G1, // HeadNum
-                                      float alpha,
-                                      bool input_permute,
-                                      bool output_permute)
+    kernel_multi_query_attention_wmma(const ADataType* __restrict__ p_a_grid,
+                                      const B0DataType* __restrict__ p_b0_grid,
+                                      const B1DataType* __restrict__ p_b1_grid,
+                                      CDataType* __restrict__ p_c_grid,
+                                      index_t M,  // SequenceQ
+                                      index_t N,  // SequenceK
+                                      index_t K,  // HeadDim
+                                      index_t O,  // SequenceK
+                                      index_t G0, // Batch
+                                      index_t G1, // HeadNum
+                                      float alpha,
+                                      bool input_permute,
+                                      bool output_permute)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))

@@ -36,25 +36,25 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_contraction_multiple_d_xdl_cshuffle(
-        const FloatAB* __restrict__ p_a_grid,
-        const FloatAB* __restrict__ p_b_grid,
-        FloatDsPointer p_ds_grid,
-        FloatE* __restrict__ p_e_grid,
-        const index_t batch_count,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
-        const CDEElementwiseOperation cde_element_op,
-        const AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1,
-        const BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1,
-        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            ds_grid_desc_mblock_mperblock_nblock_nperblock,
-        const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            e_grid_desc_mblock_mperblock_nblock_nperblock,
-        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-        const Block2ETileMap block_2_etile_map)
+    kernel_contraction_multiple_d_xdl_cshuffle(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        FloatDsPointer p_ds_grid,
+        FloatE* __restrict__ p_e_grid,
+        const index_t batch_count,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CDEElementwiseOperation cde_element_op,
+        const AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1,
+        const BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1,
+        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+        const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            e_grid_desc_mblock_mperblock_nblock_nperblock,
+        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+        const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];