mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Add grouped conv fwd bias relu instances (#2179)
* Add grouped conv fwd bias relu instances * fixes * fix
This commit is contained in:
@@ -279,9 +279,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
static constexpr bool isMultiD = DsDataType::Size() > 0;
|
||||
static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD;
|
||||
|
||||
// multi ABD not supported
|
||||
static_assert(!isMultiABD, "Multi A, Mutli B and Multi D are not supported");
|
||||
|
||||
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
|
||||
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
@@ -1080,91 +1077,96 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float avg_time = 0.f;
|
||||
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(!isMultiABD)
|
||||
{
|
||||
const index_t a_grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
arg.a_in_transpose_desc_);
|
||||
const index_t b_grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
|
||||
arg.b_in_transpose_desc_);
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const index_t a_grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
arg.a_in_transpose_desc_);
|
||||
const index_t b_grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
|
||||
arg.b_in_transpose_desc_);
|
||||
|
||||
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
|
||||
BDataType* p_b_out_grid = type_convert<BDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
|
||||
BDataType* p_b_out_grid =
|
||||
type_convert<BDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
|
||||
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseInputTranspose,
|
||||
GridwiseElementwiseWeightTranspose,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<GKCYXTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<GKYXCTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<const BDataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
ck::Tuple<BDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
auto kernel_transpose =
|
||||
kernel_elementwise_dual<GridwiseElementwiseInputTranspose,
|
||||
GridwiseElementwiseWeightTranspose,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<GKCYXTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<GKYXCTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<const BDataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
ck::Tuple<BDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(a_grid_size + b_grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.a_in_transpose_desc_),
|
||||
make_tuple(arg.b_in_transpose_desc_),
|
||||
make_tuple(arg.a_out_transpose_desc_),
|
||||
make_tuple(arg.b_out_transpose_desc_),
|
||||
make_tuple(arg.p_a_grid_),
|
||||
make_tuple(arg.p_b_grid_),
|
||||
make_tuple(p_a_out_grid),
|
||||
make_tuple(p_b_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_,
|
||||
arg.elementwise_block_2_ctile_map_transpose_b_,
|
||||
element_wise::PassThrough{},
|
||||
a_grid_size);
|
||||
avg_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(a_grid_size + b_grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.a_in_transpose_desc_),
|
||||
make_tuple(arg.b_in_transpose_desc_),
|
||||
make_tuple(arg.a_out_transpose_desc_),
|
||||
make_tuple(arg.b_out_transpose_desc_),
|
||||
make_tuple(arg.p_a_grid_),
|
||||
make_tuple(arg.p_b_grid_),
|
||||
make_tuple(p_a_out_grid),
|
||||
make_tuple(p_b_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_,
|
||||
arg.elementwise_block_2_ctile_map_transpose_b_,
|
||||
element_wise::PassThrough{},
|
||||
a_grid_size);
|
||||
}
|
||||
|
||||
avg_time += RunGemm(arg, stream_config);
|
||||
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
arg.e_in_transpose_desc_);
|
||||
|
||||
const EDataType* p_e_in_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
|
||||
EDataType* p_e_out_grid = arg.p_e_grid_;
|
||||
|
||||
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<const EDataType*>,
|
||||
ck::Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.e_in_transpose_desc_),
|
||||
make_tuple(arg.e_out_transpose_desc_),
|
||||
make_tuple(p_e_in_grid),
|
||||
make_tuple(p_e_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_,
|
||||
element_wise::PassThrough{});
|
||||
}
|
||||
}
|
||||
|
||||
avg_time += RunGemm(arg, stream_config);
|
||||
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
arg.e_in_transpose_desc_);
|
||||
|
||||
const EDataType* p_e_in_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
|
||||
EDataType* p_e_out_grid = arg.p_e_grid_;
|
||||
|
||||
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<const EDataType*>,
|
||||
ck::Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.e_in_transpose_desc_),
|
||||
make_tuple(arg.e_out_transpose_desc_),
|
||||
make_tuple(p_e_in_grid),
|
||||
make_tuple(p_e_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_,
|
||||
element_wise::PassThrough{});
|
||||
}
|
||||
|
||||
return avg_time;
|
||||
}
|
||||
|
||||
@@ -1182,6 +1184,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
|
||||
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
|
||||
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
|
||||
// Move this to runtime check to align Conv instances
|
||||
// with Conv Multiple D instances
|
||||
if constexpr(isMultiABD)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check device
|
||||
if(get_device_name() == "gfx908")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -192,7 +192,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static constexpr index_t MaxGemmsNum = 32;
|
||||
static_assert(NumDTensor == 0, "MultiD not supported.");
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -440,89 +439,94 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
// Perform grouped gemm, generate array of tranformer for convolution
|
||||
Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformer_arr;
|
||||
Array<const ADataType*, MaxGemmsNum> a_grid_ptrs;
|
||||
Array<EDataType*, MaxGemmsNum> c_grid_ptrs;
|
||||
|
||||
ck::tie(conv_to_gemm_transformer_arr,
|
||||
a_grid_ptrs,
|
||||
c_grid_ptrs,
|
||||
gemms_count_,
|
||||
is_split_valid_) =
|
||||
GenerateConvToGemmTransforms(
|
||||
ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_,
|
||||
a_g_n_c_wis_strides_,
|
||||
b_g_k_c_xs_lengths_,
|
||||
b_g_k_c_xs_strides_,
|
||||
e_g_n_k_wos_lengths_,
|
||||
e_g_n_k_wos_strides_,
|
||||
conv_filter_strides_,
|
||||
conv_filter_dilations_,
|
||||
input_left_pads_,
|
||||
input_right_pads_},
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<EDataType*>(p_e));
|
||||
|
||||
grid_size_ = 0;
|
||||
valid_gemms_count_ = 0;
|
||||
|
||||
if(is_split_valid_)
|
||||
if constexpr(NumDTensor == 0)
|
||||
{
|
||||
// Create GemmArg for each gemm(conv)
|
||||
for(index_t i = 0; i < gemms_count_; i++)
|
||||
// Perform grouped gemm, generate array of tranformer for convolution
|
||||
Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformer_arr;
|
||||
Array<const ADataType*, MaxGemmsNum> a_grid_ptrs;
|
||||
Array<EDataType*, MaxGemmsNum> c_grid_ptrs;
|
||||
|
||||
ck::tie(conv_to_gemm_transformer_arr,
|
||||
a_grid_ptrs,
|
||||
c_grid_ptrs,
|
||||
gemms_count_,
|
||||
is_split_valid_) =
|
||||
GenerateConvToGemmTransforms(
|
||||
ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_,
|
||||
a_g_n_c_wis_strides_,
|
||||
b_g_k_c_xs_lengths_,
|
||||
b_g_k_c_xs_strides_,
|
||||
e_g_n_k_wos_lengths_,
|
||||
e_g_n_k_wos_strides_,
|
||||
conv_filter_strides_,
|
||||
conv_filter_dilations_,
|
||||
input_left_pads_,
|
||||
input_right_pads_},
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<EDataType*>(p_e));
|
||||
|
||||
grid_size_ = 0;
|
||||
valid_gemms_count_ = 0;
|
||||
|
||||
if(is_split_valid_)
|
||||
{
|
||||
const AGridDesc_M_K a_grid_desc_m_k{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(
|
||||
conv_to_gemm_transformer_arr[i])};
|
||||
const BGridDesc_N_K b_grid_desc_n_k{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(
|
||||
conv_to_gemm_transformer_arr[i])};
|
||||
const auto e_grid_desc_m_n =
|
||||
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_arr[i]);
|
||||
|
||||
const auto block_2_etile_map =
|
||||
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
|
||||
|
||||
const index_t grid_size_grp =
|
||||
block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
|
||||
|
||||
const index_t BlockStart = grid_size_;
|
||||
const index_t BlockEnd = grid_size_ + grid_size_grp;
|
||||
|
||||
grid_size_ += grid_size_grp;
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
|
||||
b_grid_desc_n_k,
|
||||
Tuple<>{},
|
||||
e_grid_desc_m_n,
|
||||
block_2_etile_map))
|
||||
// Create GemmArg for each gemm(conv)
|
||||
for(index_t i = 0; i < gemms_count_; i++)
|
||||
{
|
||||
const AGridDesc_M_K a_grid_desc_m_k{
|
||||
DeviceOp::MakeAGridDescriptor_M_K<ALayout>(
|
||||
conv_to_gemm_transformer_arr[i])};
|
||||
const BGridDesc_N_K b_grid_desc_n_k{
|
||||
DeviceOp::MakeBGridDescriptor_N_K<BLayout>(
|
||||
conv_to_gemm_transformer_arr[i])};
|
||||
const auto e_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
|
||||
conv_to_gemm_transformer_arr[i]);
|
||||
|
||||
gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{
|
||||
a_grid_ptrs[i],
|
||||
static_cast<const BDataType*>(p_b),
|
||||
c_grid_ptrs[i],
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k),
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k),
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n),
|
||||
block_2_etile_map,
|
||||
BlockStart,
|
||||
BlockEnd};
|
||||
const auto block_2_etile_map =
|
||||
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
|
||||
|
||||
valid_gemms_count_++;
|
||||
const index_t grid_size_grp =
|
||||
block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
|
||||
|
||||
const index_t BlockStart = grid_size_;
|
||||
const index_t BlockEnd = grid_size_ + grid_size_grp;
|
||||
|
||||
grid_size_ += grid_size_grp;
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
|
||||
b_grid_desc_n_k,
|
||||
Tuple<>{},
|
||||
e_grid_desc_m_n,
|
||||
block_2_etile_map))
|
||||
{
|
||||
|
||||
gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{
|
||||
a_grid_ptrs[i],
|
||||
static_cast<const BDataType*>(p_b),
|
||||
c_grid_ptrs[i],
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k),
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k),
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n),
|
||||
block_2_etile_map,
|
||||
BlockStart,
|
||||
BlockEnd};
|
||||
|
||||
valid_gemms_count_++;
|
||||
}
|
||||
}
|
||||
// N is the same for all convs
|
||||
conv_N_per_block_ = static_cast<index_t>(conv_to_gemm_transformer_arr[I0].N_);
|
||||
}
|
||||
// N is the same for all convs
|
||||
conv_N_per_block_ = static_cast<index_t>(conv_to_gemm_transformer_arr[I0].N_);
|
||||
|
||||
// Strides for G and N remain the same
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
|
||||
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
|
||||
}
|
||||
|
||||
// Strides for G and N remain the same
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
|
||||
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
|
||||
}
|
||||
|
||||
void Print() const
|
||||
@@ -578,55 +582,63 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
{
|
||||
float Run(const DeviceOp::Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
if constexpr(NumDTensor == 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
const index_t num_workgroups_per_Conv_N =
|
||||
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
|
||||
const index_t num_workgroups_per_Conv_N =
|
||||
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
|
||||
|
||||
const index_t gdx = arg.grid_size_;
|
||||
const index_t gdy = arg.num_group_;
|
||||
const index_t gdz = num_workgroups_per_Conv_N;
|
||||
const index_t gdx = arg.grid_size_;
|
||||
const index_t gdy = arg.num_group_;
|
||||
const index_t gdz = num_workgroups_per_Conv_N;
|
||||
|
||||
// K is constant for all gemms
|
||||
const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
|
||||
arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
// K is constant for all gemms
|
||||
const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
|
||||
arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
MaxGemmsNum,
|
||||
GemmArgs,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
has_main_loop>;
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
const auto kernel =
|
||||
kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
MaxGemmsNum,
|
||||
GemmArgs,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.gemm_desc_kernel_args_,
|
||||
arg.gemms_count_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
};
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.gemm_desc_kernel_args_,
|
||||
arg.gemms_count_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
};
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, true>{});
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, false>{});
|
||||
return 0.f;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -643,6 +655,12 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
|
||||
const long_index_t K = arg.b_g_k_c_xs_lengths_[I1];
|
||||
const long_index_t C = arg.b_g_k_c_xs_lengths_[I2];
|
||||
// Move this to runtime check to align Conv instances
|
||||
// with Conv Multiple D instances
|
||||
if constexpr(NumDTensor != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if all descs are valid
|
||||
if(!(arg.is_split_valid_ && arg.gemms_count_ == arg.valid_gemms_count_))
|
||||
|
||||
Reference in New Issue
Block a user