mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Support large tensors in grouped conv fwd (#1332)
* Support large tensors in grouped conv fwd * Multi ABD fixes * Fix calculate element space size
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -247,7 +247,8 @@ struct DeviceColumnToImageImpl
|
||||
independent_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads_with_offset,
|
||||
input_right_pads);
|
||||
input_right_pads,
|
||||
N);
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -93,12 +93,9 @@ __global__ void
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
|
||||
const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
|
||||
const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
|
||||
@@ -54,12 +54,9 @@ __global__ void
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
|
||||
const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
|
||||
const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
|
||||
const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
|
||||
|
||||
__shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
|
||||
|
||||
|
||||
@@ -66,12 +66,9 @@ __global__ void
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
|
||||
const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
|
||||
const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
|
||||
const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
|
||||
|
||||
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
|
||||
|
||||
|
||||
@@ -59,12 +59,9 @@ __global__ void
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
|
||||
const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
|
||||
const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -116,12 +113,9 @@ __global__ void
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
|
||||
const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
|
||||
const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
|
||||
|
||||
// Pass two lds pointer is the key to tell compiler that ds_read/write
|
||||
// operate on different lds chunk at same time without order dependecy
|
||||
@@ -1268,7 +1262,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
arg.Conv_G_;
|
||||
|
||||
std::array<index_t, I1> in_out_batch_strides = {
|
||||
arg.compute_ptr_offset_of_batch_.BatchStrideC_};
|
||||
static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};
|
||||
|
||||
const auto kernel = kernel_batched_elementwise<GridwiseElementwise,
|
||||
ck::Tuple<CElementwiseGridDesc_M_N>,
|
||||
|
||||
@@ -61,12 +61,9 @@ __global__ void
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
|
||||
const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
|
||||
const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
|
||||
const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
|
||||
|
||||
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -97,12 +97,9 @@ __global__ void
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
|
||||
const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
|
||||
const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
@@ -266,7 +263,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
input_right_pads,
|
||||
a_g_n_c_wis_lengths[I1]);
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
@@ -312,8 +310,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides);
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
|
||||
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -263,7 +263,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
input_right_pads,
|
||||
a_g_n_c_wis_lengths[I1]);
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
@@ -310,8 +311,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(c_g_n_k_wos_lengths,
|
||||
c_g_n_k_wos_strides);
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(
|
||||
c_g_n_k_wos_lengths, c_g_n_k_wos_strides, c_g_n_k_wos_lengths[I1]);
|
||||
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
|
||||
@@ -69,7 +69,8 @@ template <typename GridwiseGemm,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2ETileMap,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
typename ComputePtrOffsetOfG,
|
||||
typename ComputePtrOffsetOfN,
|
||||
bool HasMainKBlockLoop,
|
||||
bool isMultiA,
|
||||
bool isMultiB>
|
||||
@@ -85,7 +86,7 @@ __global__ void
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const index_t batch_count,
|
||||
const index_t groups_count,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
@@ -93,18 +94,22 @@ __global__ void
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
const Block2ETileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
const ComputePtrOffsetOfG compute_ptr_offset_of_groups,
|
||||
const ComputePtrOffsetOfN compute_ptr_offset_of_n)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
|
||||
const index_t& num_blocks_per_n = groups_count;
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
|
||||
|
||||
const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
|
||||
const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
|
||||
|
||||
const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -121,13 +126,28 @@ __global__ void
|
||||
AsPointer p_as_grid_grp;
|
||||
BsPointer p_bs_grid_grp;
|
||||
|
||||
const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx);
|
||||
const auto& as_batch_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);
|
||||
|
||||
static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
|
||||
static_for<0, NumATensor, 1>{}(
|
||||
[&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; });
|
||||
// compute_ptr_offset_of_n_ not need BatchStrideB so
|
||||
// in case of MultiA is false but isMultiB is true
|
||||
// BatchStrideA_ is not tuple.
|
||||
if constexpr(isMultiA)
|
||||
{
|
||||
const auto& as_n_offset = compute_ptr_offset_of_n.GetAsPtrOffset(n_idx);
|
||||
|
||||
const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx);
|
||||
static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + as_n_offset[i];
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
|
||||
static_for<0, 1, 1>{}(
|
||||
[&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + a_n_offset; });
|
||||
}
|
||||
|
||||
const auto& bs_batch_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);
|
||||
|
||||
static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
@@ -137,7 +157,7 @@ __global__ void
|
||||
p_as_grid_grp,
|
||||
p_bs_grid_grp,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset,
|
||||
p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
@@ -150,16 +170,16 @@ __global__ void
|
||||
}
|
||||
else
|
||||
{
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
|
||||
const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
|
||||
|
||||
const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(
|
||||
p_as_grid + a_batch_offset,
|
||||
p_as_grid + a_batch_offset + a_n_offset,
|
||||
p_bs_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset,
|
||||
p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
@@ -175,7 +195,7 @@ __global__ void
|
||||
ignore = p_bs_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = p_e_grid;
|
||||
ignore = batch_count;
|
||||
ignore = groups_count;
|
||||
ignore = a_grid_desc_k0_m_k1;
|
||||
ignore = b_grid_desc_k0_n_k1;
|
||||
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
@@ -183,7 +203,8 @@ __global__ void
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = compute_ptr_offset_of_groups;
|
||||
ignore = compute_ptr_offset_of_n;
|
||||
ignore = block_2_ctile_map;
|
||||
#endif
|
||||
}
|
||||
@@ -309,7 +330,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const index_t Conv_N)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
|
||||
@@ -321,7 +343,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
input_right_pads,
|
||||
Conv_N);
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
@@ -347,11 +370,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
template <typename ELay>
|
||||
static auto
|
||||
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const index_t Conv_N)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides);
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
|
||||
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
@@ -363,24 +387,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// Pass e_g_n_k_wos_lengths for logical broadcast.
|
||||
static auto MakeDsGridDescriptor_M_N(
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
|
||||
const index_t Conv_N)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(e_g_n_k_wos_lengths,
|
||||
ds_g_n_k_wos_strides[i]);
|
||||
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
|
||||
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], Conv_N);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
|
||||
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>;
|
||||
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
|
||||
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, 1))>;
|
||||
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>;
|
||||
|
||||
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert
|
||||
// it to it
|
||||
@@ -468,6 +493,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
num_group_{a_g_n_c_wis_lengths[0]},
|
||||
conv_N_per_block_{
|
||||
conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
|
||||
a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides)},
|
||||
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
@@ -477,12 +508,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads)},
|
||||
input_right_pads,
|
||||
conv_N_per_block_)},
|
||||
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides)},
|
||||
ds_grid_desc_m_n_{},
|
||||
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides)},
|
||||
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)},
|
||||
a_grid_desc_ak0_m_ak1_{
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
|
||||
b_grid_desc_bk0_n_bk1_{
|
||||
@@ -490,7 +522,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
compute_ptr_offset_of_batch_{},
|
||||
compute_ptr_offset_of_groups_{},
|
||||
compute_ptr_offset_of_n_{},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
@@ -511,8 +544,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
{
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
// Init compute_ptr_offset_of_batch_ for multiple AB
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
|
||||
// Init compute_ptr_offset_of_groups_ for multiple AB
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
|
||||
|
||||
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
|
||||
// type is not tuple)
|
||||
@@ -524,16 +557,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
// p_as is tuple
|
||||
p_as_grid_(i) = static_cast<const DataType*>(p_as[i.value]);
|
||||
// compute_ptr_offset_of_n_ not need BatchStrideB so
|
||||
// in case of MultiA is false but isMultiB is true
|
||||
// BatchStrideA_ is not tuple.
|
||||
compute_ptr_offset_of_n_.BatchStrideA_(i) =
|
||||
a_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
}
|
||||
else
|
||||
{
|
||||
// if MultiB and not MultiA then p_as is single pointer
|
||||
p_as_grid_(i) = static_cast<const DataType*>(p_as);
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ =
|
||||
a_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
}
|
||||
});
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
// Init compute_ptr_offset_of_batch_ for multiple AB
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
|
||||
// Init compute_ptr_offset_of_groups_ for multiple AB
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
|
||||
|
||||
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
|
||||
// It is possible that one of the AB is a pointer and one is a tuple.
|
||||
@@ -553,8 +593,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
else
|
||||
{
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
|
||||
// p_as and p_bs are pointers
|
||||
p_as_grid_(I0) = static_cast<const ADataType*>(p_as);
|
||||
@@ -570,13 +611,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
|
||||
|
||||
// D batch stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
|
||||
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
|
||||
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
|
||||
|
||||
// D desc
|
||||
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
|
||||
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]);
|
||||
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], conv_N_per_block_);
|
||||
});
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
|
||||
|
||||
// populate desc for Ds/E
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
@@ -638,6 +682,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
index_t num_group_;
|
||||
index_t conv_N_per_block_;
|
||||
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_N_K b_grid_desc_n_k_;
|
||||
DsGridDesc_M_N ds_grid_desc_m_n_;
|
||||
@@ -655,7 +701,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
|
||||
compute_ptr_offset_of_batch_;
|
||||
compute_ptr_offset_of_groups_;
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor> compute_ptr_offset_of_n_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
@@ -689,8 +736,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;
|
||||
const index_t num_workgroups_per_Conv_N =
|
||||
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
|
||||
|
||||
const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
|
||||
const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N;
|
||||
const index_t gdz = 1;
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
@@ -721,6 +772,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
|
||||
has_main_loop,
|
||||
isMultiA,
|
||||
isMultiB>;
|
||||
@@ -728,7 +780,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_as_grid_,
|
||||
@@ -744,7 +796,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -763,6 +816,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
|
||||
has_main_loop,
|
||||
isMultiA,
|
||||
isMultiB>;
|
||||
@@ -770,7 +824,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_as_grid_.At(I0), // Pass just A descriptor instead of tuple
|
||||
@@ -786,7 +840,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ template <typename GridwiseGemm,
|
||||
typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
typename ComputePtrOffset,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
@@ -69,26 +69,28 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_grouped_conv_fwd_xdl_cshuffle_v3(
|
||||
typename GridwiseGemm::Argument karg,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const index_t batch_count)
|
||||
kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffset compute_ptr_offset_of_groups,
|
||||
const ComputePtrOffset compute_ptr_offset_of_n,
|
||||
const index_t groups_count)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / batch_count);
|
||||
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
|
||||
const index_t& num_blocks_per_n = groups_count;
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
|
||||
const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
|
||||
const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
|
||||
|
||||
const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
|
||||
const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -97,9 +99,9 @@ __global__ void
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
karg.p_c_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
@@ -114,7 +116,7 @@ template <typename GridwiseGemm,
|
||||
typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
typename ComputePtrOffset,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
@@ -129,20 +131,23 @@ __global__ void
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const index_t batch_count)
|
||||
const ComputePtrOffset compute_ptr_offset_of_groups,
|
||||
const ComputePtrOffset compute_ptr_offset_of_n,
|
||||
const index_t groups_count)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / batch_count);
|
||||
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
|
||||
const index_t& num_blocks_per_n = groups_count;
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
|
||||
const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
|
||||
const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
|
||||
|
||||
const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
|
||||
const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
|
||||
|
||||
// Pass two lds pointer is the key to tell compiler that ds_read/write
|
||||
// operate on different lds chunk at same time without order dependecy
|
||||
@@ -154,9 +159,9 @@ __global__ void
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
karg.p_c_grid + e_batch_offset + e_n_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
@@ -294,7 +299,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const index_t Conv_N)
|
||||
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
|
||||
@@ -306,7 +313,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
input_right_pads,
|
||||
Conv_N);
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
@@ -350,11 +358,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
template <typename ELay>
|
||||
static auto
|
||||
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const index_t Conv_N)
|
||||
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides);
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
|
||||
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
@@ -363,7 +373,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
|
||||
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>;
|
||||
|
||||
#define GridwiseGemmV3TemplateParams \
|
||||
tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \
|
||||
@@ -396,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
|
||||
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
@@ -429,6 +439,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
p_b_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
num_group_{a_g_n_c_wis_lengths[0]},
|
||||
conv_N_per_block_{
|
||||
conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
|
||||
a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides)},
|
||||
a_grid_desc_ak0_m_ak1_{MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
@@ -438,13 +454,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads)},
|
||||
input_right_pads,
|
||||
conv_N_per_block_)},
|
||||
b_grid_desc_bk0_n_bk1_{
|
||||
MakeBGridDescriptor_BK0_N_BK1<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
|
||||
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides)},
|
||||
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
compute_ptr_offset_of_batch_{},
|
||||
compute_ptr_offset_of_groups_{},
|
||||
compute_ptr_offset_of_n_{},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
@@ -459,15 +477,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
// A/B/E Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
// A/B/E Batch/N Stride
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
|
||||
// p_as and p_bs are pointers
|
||||
p_a_grid_ = static_cast<const ADataType*>(p_as);
|
||||
p_b_grid_ = static_cast<const BDataType*>(p_bs);
|
||||
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
|
||||
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
|
||||
@@ -488,6 +508,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
index_t num_group_;
|
||||
index_t conv_N_per_block_;
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
@@ -496,7 +517,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_groups_;
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
@@ -538,11 +560,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
const index_t GemmK =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
const index_t num_workgroups_per_Conv_N =
|
||||
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) =
|
||||
GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/);
|
||||
|
||||
gdy *= arg.num_group_;
|
||||
gdy *= arg.num_group_ * num_workgroups_per_Conv_N;
|
||||
|
||||
index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
@@ -579,7 +604,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.num_group_);
|
||||
}
|
||||
else
|
||||
@@ -594,7 +620,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.num_group_);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -338,7 +338,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
input_right_pads,
|
||||
a_g_n_c_wis_lengths[I1]);
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
@@ -367,8 +368,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides);
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
|
||||
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -163,7 +163,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
input_right_pads,
|
||||
a_g_n_c_wis_lengths[I1]);
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
@@ -255,8 +256,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides);
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
|
||||
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
|
||||
@@ -68,14 +68,14 @@ template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
|
||||
struct ComputePtrOffsetOfStridedBatch<NumATensor,
|
||||
NumBTensor,
|
||||
NumDTensor,
|
||||
ck::enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
|
||||
enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch() = default;
|
||||
|
||||
ComputePtrOffsetOfStridedBatch(Array<ck::index_t, NumATensor>& BatchStrideAs,
|
||||
Array<ck::index_t, NumBTensor>& BatchStrideBs,
|
||||
Array<ck::index_t, NumDTensor>& BatchStrideDs,
|
||||
index_t BatchStrideE)
|
||||
ComputePtrOffsetOfStridedBatch(Array<long_index_t, NumATensor>& BatchStrideAs,
|
||||
Array<long_index_t, NumBTensor>& BatchStrideBs,
|
||||
Array<long_index_t, NumDTensor>& BatchStrideDs,
|
||||
long_index_t BatchStrideE)
|
||||
: BatchStrideA_(BatchStrideAs),
|
||||
BatchStrideB_(BatchStrideBs),
|
||||
BatchStrideDs_(BatchStrideDs),
|
||||
@@ -87,7 +87,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
|
||||
{
|
||||
Array<long_index_t, NumATensor> as_offset;
|
||||
static_for<0, NumATensor, 1>{}(
|
||||
[&](auto i) { as_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideA_[i]); });
|
||||
[&](auto i) { as_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideA_[i]; });
|
||||
return as_offset;
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
|
||||
{
|
||||
Array<long_index_t, NumBTensor> bs_offset;
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
[&](auto i) { bs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideB_[i]); });
|
||||
[&](auto i) { bs_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideB_[i]; });
|
||||
return bs_offset;
|
||||
}
|
||||
|
||||
@@ -103,40 +103,40 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
|
||||
{
|
||||
Array<long_index_t, NumDTensor> ds_offset;
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
|
||||
[&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
|
||||
return ds_offset;
|
||||
}
|
||||
|
||||
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE_);
|
||||
return static_cast<long_index_t>(g_idx) * BatchStrideE_;
|
||||
}
|
||||
|
||||
// alias for kernels without multiple D
|
||||
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE_);
|
||||
return static_cast<long_index_t>(g_idx) * BatchStrideE_;
|
||||
}
|
||||
|
||||
Array<ck::index_t, NumATensor> BatchStrideA_;
|
||||
Array<ck::index_t, NumBTensor> BatchStrideB_;
|
||||
Array<ck::index_t, NumDTensor> BatchStrideDs_;
|
||||
index_t BatchStrideE_;
|
||||
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
|
||||
Array<long_index_t, NumATensor> BatchStrideA_;
|
||||
Array<long_index_t, NumBTensor> BatchStrideB_;
|
||||
Array<long_index_t, NumDTensor> BatchStrideDs_;
|
||||
long_index_t BatchStrideE_;
|
||||
long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
|
||||
};
|
||||
|
||||
template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
|
||||
struct ComputePtrOffsetOfStridedBatch<NumATensor,
|
||||
NumBTensor,
|
||||
NumDTensor,
|
||||
ck::enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
|
||||
enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch() = default;
|
||||
|
||||
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
Array<ck::index_t, NumDTensor> BatchStrideDs,
|
||||
index_t BatchStrideE)
|
||||
ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA,
|
||||
long_index_t BatchStrideB,
|
||||
Array<long_index_t, NumDTensor> BatchStrideDs,
|
||||
long_index_t BatchStrideE)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB_(BatchStrideB),
|
||||
BatchStrideDs_(BatchStrideDs),
|
||||
@@ -146,38 +146,38 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
return static_cast<long_index_t>(g_idx) * BatchStrideA_;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB_);
|
||||
return static_cast<long_index_t>(g_idx) * BatchStrideB_;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
|
||||
{
|
||||
Array<long_index_t, NumDTensor> ds_offset;
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
|
||||
[&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
|
||||
return ds_offset;
|
||||
}
|
||||
|
||||
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE_);
|
||||
return static_cast<long_index_t>(g_idx) * BatchStrideE_;
|
||||
}
|
||||
|
||||
// alias for kernels without multiple D
|
||||
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE_);
|
||||
return static_cast<long_index_t>(g_idx) * BatchStrideE_;
|
||||
}
|
||||
|
||||
ck::index_t BatchStrideA_;
|
||||
ck::index_t BatchStrideB_;
|
||||
Array<ck::index_t, NumDTensor> BatchStrideDs_;
|
||||
index_t BatchStrideE_;
|
||||
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
|
||||
long_index_t BatchStrideA_;
|
||||
long_index_t BatchStrideB_;
|
||||
Array<long_index_t, NumDTensor> BatchStrideDs_;
|
||||
long_index_t BatchStrideE_;
|
||||
long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
|
||||
};
|
||||
|
||||
template <bool isTuple, typename Tensors>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -108,7 +108,8 @@ struct DeviceImageToColumnImpl
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
input_right_pads,
|
||||
N);
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
|
||||
Reference in New Issue
Block a user