mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
refactor
This commit is contained in:
@@ -726,11 +726,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
const index_t grid_size =
|
||||
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
|
||||
|
||||
const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1<
|
||||
GridwiseGemm,
|
||||
|
||||
@@ -430,13 +430,12 @@ struct DeviceBatchedGemmXdl
|
||||
const index_t grid_size =
|
||||
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_xdlops_v2r3<
|
||||
GridwiseGemm,
|
||||
|
||||
@@ -582,11 +582,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
const index_t grid_size =
|
||||
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0);
|
||||
const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
|
||||
arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<
|
||||
GridwiseGemm,
|
||||
|
||||
@@ -698,13 +698,12 @@ struct
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r3<
|
||||
GridwiseGemm,
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_NHWC_KYXC_NHWK_HPP
|
||||
#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
@@ -660,13 +658,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r2<
|
||||
GridwiseGemm,
|
||||
@@ -919,4 +916,3 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -640,13 +640,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r1<
|
||||
GridwiseGemm,
|
||||
|
||||
@@ -478,13 +478,12 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<
|
||||
GridwiseGemm,
|
||||
|
||||
@@ -1296,11 +1296,10 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
|
||||
const index_t grid_size =
|
||||
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0);
|
||||
const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
|
||||
arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<
|
||||
GridwiseGemm,
|
||||
|
||||
@@ -775,13 +775,12 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<
|
||||
GridwiseGemm,
|
||||
|
||||
@@ -530,11 +530,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1<
|
||||
GridwiseGemm,
|
||||
|
||||
@@ -292,6 +292,7 @@ struct DeviceGemmXdl
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
@@ -304,6 +305,7 @@ struct DeviceGemmXdl
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
@@ -320,11 +322,9 @@ struct DeviceGemmXdl
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<
|
||||
GridwiseGemm,
|
||||
|
||||
@@ -1,478 +0,0 @@
|
||||
#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_HPP
|
||||
#define DEVICE_GEMM_XDL_C_SHUFFLE_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_gemm.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v3r1.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
index_t NumPrefetch = 1>
|
||||
struct DeviceGemmXdl_C_Shuffle
|
||||
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
|
||||
{
|
||||
assert(K % AK1 == 0);
|
||||
|
||||
const index_t K0 = K / AK1;
|
||||
|
||||
const auto a_grid_desc_m_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto a_grid_desc_k0_m_k1 = transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, AK1)), make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_k0_m_k1;
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
|
||||
{
|
||||
assert(K % BK1 == 0);
|
||||
|
||||
const index_t K0 = K / BK1;
|
||||
|
||||
const auto b_grid_desc_k_n = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto b_grid_desc_k0_n_k1 = transform_tensor_descriptor(
|
||||
b_grid_desc_k_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, BK1)), make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_k0_n_k1;
|
||||
}
|
||||
|
||||
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
|
||||
{
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
|
||||
}
|
||||
}
|
||||
|
||||
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
|
||||
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
NumPrefetch>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t M01,
|
||||
index_t N01,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
{
|
||||
a_grid_desc_k0_m_k1_ =
|
||||
DeviceGemmXdl_C_Shuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
|
||||
b_grid_desc_k0_n_k1_ =
|
||||
DeviceGemmXdl_C_Shuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
|
||||
c_grid_desc_m_n_ = DeviceGemmXdl_C_Shuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
|
||||
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
|
||||
{
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
|
||||
GridwiseGemm::
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
c_grid_desc_m_n_);
|
||||
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceGemmXdl_C_Shuffle::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
|
||||
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r1<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdl_C_Shuffle::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdl_C_Shuffle::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r1<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdl_C_Shuffle::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdl_C_Shuffle::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t /* KBatch */ = 1) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmXdl_C_Shuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,6 +1,4 @@
|
||||
#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_2D_HPP
|
||||
#define DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_2D_HPP
|
||||
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
@@ -291,18 +289,17 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
|
||||
arg.N01_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r2 has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r2<
|
||||
GridwiseGemm,
|
||||
@@ -505,4 +502,3 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -303,13 +303,12 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r2<
|
||||
GridwiseGemm,
|
||||
|
||||
@@ -345,13 +345,12 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r3<
|
||||
GridwiseGemm,
|
||||
|
||||
@@ -465,11 +465,9 @@ struct DeviceGemm_Xdl_CShuffle
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
|
||||
GridwiseGemm,
|
||||
|
||||
@@ -467,11 +467,9 @@ struct DeviceGemm_Xdl_CShuffle_v2
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v2<
|
||||
GridwiseGemm,
|
||||
|
||||
@@ -450,59 +450,53 @@ struct DeviceGroupedGemmXdl
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_arg_arg;
|
||||
StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_args;
|
||||
|
||||
bool has_main_k0_block_loop = true;
|
||||
bool has_main_k_block_loop = true;
|
||||
|
||||
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
|
||||
if(i < arg.gemm_desc_kernel_arg_.size())
|
||||
{
|
||||
gemm_desc_kernel_arg_arg(i) = arg.gemm_desc_kernel_arg_[i];
|
||||
gemm_desc_kernel_args(i) = arg.gemm_desc_kernel_arg_[i];
|
||||
|
||||
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
|
||||
<< gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", "
|
||||
<< gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I1)
|
||||
<< ", "
|
||||
<< gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I2)
|
||||
<< "}";
|
||||
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
|
||||
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";
|
||||
|
||||
std::cout << ", arg.b_grid_desc_k0_n_k1_{"
|
||||
<< gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I0)
|
||||
<< ", "
|
||||
<< gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I1)
|
||||
<< ", "
|
||||
<< gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I2)
|
||||
<< "}";
|
||||
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
|
||||
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";
|
||||
|
||||
std::cout << ", arg.c_grid_desc_m_n_{ "
|
||||
<< gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_.GetLength(I1) << "}"
|
||||
<< gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I1) << "}"
|
||||
<< std::endl;
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(
|
||||
gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_,
|
||||
gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_,
|
||||
gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_))
|
||||
if(!GridwiseGemm::CheckValidity(gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_,
|
||||
gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_,
|
||||
gemm_desc_kernel_args[i].c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
|
||||
}
|
||||
|
||||
const auto K0 = gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
const auto K = gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
|
||||
gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainK0BlockLoop(K0) != has_main_k0_block_loop)
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
|
||||
{
|
||||
throw std::runtime_error("wrong! not all gemm has_main_k0_block_loop");
|
||||
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
@@ -520,7 +514,7 @@ struct DeviceGroupedGemmXdl
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_desc_kernel_arg_arg,
|
||||
gemm_desc_kernel_args,
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
@@ -544,7 +538,7 @@ struct DeviceGroupedGemmXdl
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_desc_kernel_arg_arg,
|
||||
gemm_desc_kernel_args,
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
|
||||
@@ -25,7 +25,7 @@ template <typename GridwiseGemm,
|
||||
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename DGridDescriptor_MBlock_MPerBlock,
|
||||
typename Block2CTileMap,
|
||||
bool HasMainK0BlockLoop>
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
@@ -51,22 +51,22 @@ __global__ void
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_d0_grid,
|
||||
p_d1_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d0_reduce_op,
|
||||
d1_reduce_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
d_grid_desc_mblock_mperblock,
|
||||
block_2_ctile_map);
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_d0_grid,
|
||||
p_d1_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d0_reduce_op,
|
||||
d1_reduce_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
d_grid_desc_mblock_mperblock,
|
||||
block_2_ctile_map);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
@@ -154,6 +154,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
static constexpr auto AK1 = Number<AK1Value>{};
|
||||
static constexpr auto BK1 = Number<BK1Value>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
@@ -237,21 +241,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
|
||||
return false;
|
||||
|
||||
// check NumGemmKPrefetchStage
|
||||
if constexpr(NumGemmKPrefetchStage == 1)
|
||||
{
|
||||
// 1-stage prefetch always supported
|
||||
}
|
||||
else if constexpr(NumGemmKPrefetchStage == 2)
|
||||
{
|
||||
// 2-stage prefetch currently only support even number of K0 loop
|
||||
// TODO: add support for odd number of K0 loop
|
||||
if(!((K / KPerBlock) % 2 == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
// check gridwise gemm pipeline
|
||||
const auto num_k_loop = K / KPerBlock;
|
||||
|
||||
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -271,12 +264,11 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
// TODO move this function into GEMM-pipeline class
|
||||
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const bool has_main_k0_block_loop = ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1;
|
||||
const index_t num_loop = K / KPerBlock;
|
||||
|
||||
return has_main_k0_block_loop;
|
||||
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
@@ -362,7 +354,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
|
||||
|
||||
template <bool HasMainK0BlockLoop, typename Block2CTileMap>
|
||||
template <bool HasMainKBlockLoop, typename Block2CTileMap>
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
@@ -485,7 +477,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
|
||||
FloatAB,
|
||||
FloatGemmAcc,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
@@ -513,43 +505,25 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
|
||||
|
||||
// gridwise GEMM pipeline
|
||||
const auto gridwise_gemm_pipeline =
|
||||
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
|
||||
remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
|
||||
remove_cvref_t<decltype(a_blockwise_copy)>,
|
||||
remove_cvref_t<decltype(a_grid_buf)>,
|
||||
remove_cvref_t<decltype(a_block_buf)>,
|
||||
remove_cvref_t<decltype(a_block_slice_copy_step)>,
|
||||
remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
|
||||
remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
|
||||
remove_cvref_t<decltype(b_blockwise_copy)>,
|
||||
remove_cvref_t<decltype(b_grid_buf)>,
|
||||
remove_cvref_t<decltype(b_block_buf)>,
|
||||
remove_cvref_t<decltype(b_block_slice_copy_step)>,
|
||||
remove_cvref_t<decltype(blockwise_gemm)>,
|
||||
remove_cvref_t<decltype(c_thread_buf)>,
|
||||
NumGemmKPrefetchStage,
|
||||
HasMainK0BlockLoop>{};
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
|
||||
@@ -120,6 +120,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
@@ -262,7 +264,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
|
||||
}();
|
||||
|
||||
using BlockwiseGemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
@@ -487,7 +489,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
|
||||
// sanity check
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
|
||||
@@ -123,6 +123,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
@@ -476,7 +478,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
// sanity check
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
|
||||
@@ -24,7 +24,7 @@ template <typename GridwiseGemm,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename Block2CTileMap,
|
||||
bool HasMainK0BlockLoop>
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
@@ -48,7 +48,7 @@ __global__ void
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainK0BlockLoop>(
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -119,7 +119,7 @@ template <
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
index_t NumPrefetch = 1>
|
||||
index_t NumGemmKPrefetchStage = 1>
|
||||
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -134,6 +134,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
@@ -252,21 +256,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
|
||||
return false;
|
||||
|
||||
// check NumPrefetch
|
||||
if constexpr(NumPrefetch == 1)
|
||||
{
|
||||
// 1-stage prefetch always supported
|
||||
}
|
||||
else if constexpr(NumPrefetch == 2)
|
||||
{
|
||||
// 2-stage prefetch currently only support even number of K0 loop
|
||||
// TODO: add support for odd number of K0 loop
|
||||
if(!((K0 / K0PerBlock) % 2 == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
// check gridwise gemm pipeline
|
||||
const auto num_k_loop = K0 / K0PerBlock;
|
||||
|
||||
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -296,12 +289,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
// TODO move this function into GEMM-pipeline class
|
||||
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
|
||||
const index_t num_loop = K / (K0PerBlock * K1);
|
||||
|
||||
return has_main_k0_block_loop;
|
||||
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N_>
|
||||
@@ -379,7 +371,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
|
||||
|
||||
template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
@@ -455,7 +447,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
NumPrefetch>(
|
||||
NumGemmKPrefetchStage>(
|
||||
a_grid_desc_k0_m_k1,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
@@ -486,7 +478,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
NumPrefetch>(
|
||||
NumGemmKPrefetchStage>(
|
||||
b_grid_desc_k0_n_k1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
@@ -503,7 +495,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
|
||||
// sanity check
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_block_desc_k0_m_k1),
|
||||
@@ -531,41 +523,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
|
||||
|
||||
// gridwise GEMM pipeline
|
||||
const auto gridwise_gemm_pipeline =
|
||||
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
|
||||
remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
|
||||
remove_cvref_t<decltype(a_blockwise_copy)>,
|
||||
remove_cvref_t<decltype(a_grid_buf)>,
|
||||
remove_cvref_t<decltype(a_block_buf)>,
|
||||
remove_cvref_t<decltype(a_block_slice_copy_step)>,
|
||||
remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
|
||||
remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
|
||||
remove_cvref_t<decltype(b_blockwise_copy)>,
|
||||
remove_cvref_t<decltype(b_grid_buf)>,
|
||||
remove_cvref_t<decltype(b_block_buf)>,
|
||||
remove_cvref_t<decltype(b_block_slice_copy_step)>,
|
||||
remove_cvref_t<decltype(blockwise_gemm)>,
|
||||
remove_cvref_t<decltype(c_thread_buf)>,
|
||||
NumPrefetch,
|
||||
HasMainK0BlockLoop>{};
|
||||
|
||||
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
|
||||
|
||||
gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
|
||||
a_block_desc_k0_m_k1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_k0_n_k1,
|
||||
b_block_desc_k0_n_k1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
K0BlockMainLoop);
|
||||
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
|
||||
a_block_desc_k0_m_k1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_k0_n_k1,
|
||||
b_block_desc_k0_n_k1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
K0BlockMainLoop);
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
|
||||
@@ -25,7 +25,7 @@ template <typename GridwiseGemm,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename Block2CTileMap,
|
||||
bool HasMainK0BlockLoop>
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
@@ -52,7 +52,7 @@ __global__ void
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainK0BlockLoop>(
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -128,7 +128,7 @@ template <
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
index_t NumPrefetch = 1>
|
||||
index_t NumGemmKPrefetchStage = 1>
|
||||
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -143,6 +143,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
@@ -261,21 +265,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
|
||||
return false;
|
||||
|
||||
// check NumPrefetch
|
||||
if constexpr(NumPrefetch == 1)
|
||||
{
|
||||
// 1-stage prefetch always supported
|
||||
}
|
||||
else if constexpr(NumPrefetch == 2)
|
||||
{
|
||||
// 2-stage prefetch currently only support even number of K0 loop
|
||||
// TODO: add support for odd number of K0 loop
|
||||
if(!((K0 / K0PerBlock) % 2 == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
// check gridwise gemm pipeline
|
||||
const auto num_k_loop = K0 / K0PerBlock;
|
||||
|
||||
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -305,12 +298,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
// TODO move this function into GEMM-pipeline class
|
||||
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
|
||||
const index_t num_loop = K / (K0PerBlock * K1);
|
||||
|
||||
return has_main_k0_block_loop;
|
||||
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N_>
|
||||
@@ -393,7 +385,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
|
||||
|
||||
template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
@@ -522,7 +514,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
|
||||
// sanity check
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_block_desc_k0_m_k1),
|
||||
@@ -550,41 +542,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
|
||||
|
||||
// gridwise GEMM pipeline
|
||||
const auto gridwise_gemm_pipeline =
|
||||
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
|
||||
remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
|
||||
remove_cvref_t<decltype(a_blockwise_copy)>,
|
||||
remove_cvref_t<decltype(a_grid_buf)>,
|
||||
remove_cvref_t<decltype(a_block_buf)>,
|
||||
remove_cvref_t<decltype(a_block_slice_copy_step)>,
|
||||
remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
|
||||
remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
|
||||
remove_cvref_t<decltype(b_blockwise_copy)>,
|
||||
remove_cvref_t<decltype(b_grid_buf)>,
|
||||
remove_cvref_t<decltype(b_block_buf)>,
|
||||
remove_cvref_t<decltype(b_block_slice_copy_step)>,
|
||||
remove_cvref_t<decltype(blockwise_gemm)>,
|
||||
remove_cvref_t<decltype(c_thread_buf)>,
|
||||
NumPrefetch,
|
||||
HasMainK0BlockLoop>{};
|
||||
|
||||
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
|
||||
|
||||
gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
|
||||
a_block_desc_k0_m_k1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_k0_n_k1,
|
||||
b_block_desc_k0_n_k1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
K0BlockMainLoop);
|
||||
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
|
||||
a_block_desc_k0_m_k1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_k0_n_k1,
|
||||
b_block_desc_k0_n_k1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
K0BlockMainLoop);
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
@@ -623,17 +597,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
|
||||
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
|
||||
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
|
||||
make_tuple(
|
||||
make_freeze_transform(I0), // freeze mblock
|
||||
make_pass_through_transform(
|
||||
Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per shuffle
|
||||
make_unmerge_transform(
|
||||
make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
|
||||
make_freeze_transform(I0), // freeze nblock
|
||||
make_pass_through_transform(
|
||||
Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per shuffle
|
||||
make_unmerge_transform(
|
||||
make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl
|
||||
make_tuple(make_freeze_transform(I0), // freeze mblock
|
||||
make_pass_through_transform(
|
||||
Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per
|
||||
// shuffle
|
||||
make_unmerge_transform(
|
||||
make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
|
||||
make_freeze_transform(I0), // freeze nblock
|
||||
make_pass_through_transform(
|
||||
Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per
|
||||
// shuffle
|
||||
make_unmerge_transform(
|
||||
make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl_c_shuffle.hpp"
|
||||
#include "device_gemm_xdl_cshuffle.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
@@ -20,26 +20,28 @@ using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
|
||||
using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#####################| AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Num|
|
||||
//#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Prefetch|
|
||||
//#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
|
||||
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>,
|
||||
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>,
|
||||
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>,
|
||||
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>,
|
||||
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8, 2>,
|
||||
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>,
|
||||
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2>,
|
||||
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>,
|
||||
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>,
|
||||
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8, 2>,
|
||||
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>,
|
||||
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2>,
|
||||
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2>
|
||||
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user