mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Fix tensor descriptors and stride calculations
This commit is contained in:
@@ -814,16 +814,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
(Conv_N_ * output_spatial_acum) % (KPerBlock * k_batch_) == 0;
|
||||
// Check that there is no K padding and that N * OutSpatialDims is divisible by k_batch
|
||||
split_k_offset_a_hack_ =
|
||||
(Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
|
||||
k_batch_ > 1 && (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
// Check that there is no K padding and that N is divisible by k_batch
|
||||
split_k_offset_b_hack_ =
|
||||
Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
|
||||
k_batch_ > 1 && Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
|
||||
split_k_stride_a_ =
|
||||
a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_;
|
||||
split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_;
|
||||
// Calculate stride from descriptor size
|
||||
// NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
|
||||
// so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled
|
||||
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_a_hack_)
|
||||
split_k_stride_a_ /= k_batch_;
|
||||
|
||||
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_b_hack_)
|
||||
split_k_stride_b_ /= k_batch_;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
|
||||
@@ -65,7 +65,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_a_hack,
|
||||
bool split_k_offset_b_hack)
|
||||
bool split_k_offset_b_hack,
|
||||
index_t k_batch)
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__)
|
||||
@@ -96,7 +97,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
split_k_stride_a,
|
||||
split_k_stride_b,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack);
|
||||
split_k_offset_b_hack,
|
||||
k_batch);
|
||||
}
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
@@ -115,6 +117,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
ignore = split_k_stride_b;
|
||||
ignore = split_k_offset_a_hack;
|
||||
ignore = split_k_offset_b_hack;
|
||||
ignore = k_batch;
|
||||
|
||||
compute_ptr_offset_of_batch.GetAPtrOffset(0);
|
||||
compute_ptr_offset_of_batch.GetBPtrOffset(0);
|
||||
@@ -587,6 +590,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
const bool is_k_not_paded =
|
||||
(Conv_N_ * output_spatial_acum) % (K0PerBlock * K1 * k_batch_) == 0;
|
||||
// Check that there is no K padding and that N * OutSpatialDims is divisible by k_batch
|
||||
split_k_offset_a_hack_ =
|
||||
k_batch_ > 1 && (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
// Check that there is no K padding and that N is divisible by k_batch
|
||||
split_k_offset_b_hack_ =
|
||||
k_batch_ > 1 && Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
@@ -603,12 +619,25 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_);
|
||||
k_batch_,
|
||||
split_k_offset_a_hack_,
|
||||
split_k_offset_b_hack_);
|
||||
|
||||
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
|
||||
// Calculate stride from descriptor size
|
||||
// NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
|
||||
// so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled
|
||||
split_k_stride_a_ = a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_a_hack_)
|
||||
split_k_stride_a_ /= k_batch_;
|
||||
|
||||
split_k_stride_b_ = b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_b_hack_)
|
||||
split_k_stride_b_ /= k_batch_;
|
||||
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm64::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
||||
|
||||
@@ -650,23 +679,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapTranspose{
|
||||
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
|
||||
}
|
||||
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
const bool is_k_not_paded =
|
||||
(Conv_N_ * output_spatial_acum) % (K0PerBlock * K1 * k_batch_) == 0;
|
||||
// Check that there is no K padding and that N * OutSpatialDims is divisible by k_batch
|
||||
split_k_offset_a_hack_ =
|
||||
(Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
// Check that there is no K padding and that N is divisible by k_batch
|
||||
split_k_offset_b_hack_ =
|
||||
Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
|
||||
split_k_stride_a_ =
|
||||
a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_;
|
||||
split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
@@ -913,7 +925,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
arg.split_k_stride_a_,
|
||||
arg.split_k_stride_b_,
|
||||
arg.split_k_offset_a_hack_,
|
||||
arg.split_k_offset_b_hack_);
|
||||
arg.split_k_offset_b_hack_,
|
||||
arg.k_batch_);
|
||||
};
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
|
||||
@@ -514,8 +514,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
: p_a_grid_{p_out_grid},
|
||||
p_b_grid_{p_in_grid},
|
||||
p_c_grid_{p_wei_grid},
|
||||
a_grid_desc_kbatch_k0_m_k1_{},
|
||||
b_grid_desc_kbatch_k0_n_k1_{},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
compute_ptr_offset_of_batch_{},
|
||||
@@ -584,6 +584,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
const bool is_k_not_paded =
|
||||
(Conv_N_ * output_spatial_acum) % (K0PerBlock * k_batch_) == 0;
|
||||
// Check that there is no K padding and that N * OutSpatialDims is divisible by k_batch
|
||||
split_k_offset_a_hack_ =
|
||||
k_batch_ > 1 && (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
// Check that there is no K padding and that N is divisible by k_batch
|
||||
split_k_offset_b_hack_ =
|
||||
k_batch_ > 1 && Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
@@ -600,11 +613,24 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_);
|
||||
k_batch_,
|
||||
split_k_offset_a_hack_,
|
||||
split_k_offset_b_hack_);
|
||||
|
||||
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
a_grid_desc_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
|
||||
// Calculate stride from descriptor size
|
||||
// NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
|
||||
// so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled
|
||||
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_a_hack_)
|
||||
split_k_stride_a_ /= k_batch_;
|
||||
|
||||
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_b_hack_)
|
||||
split_k_stride_b_ /= k_batch_;
|
||||
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
|
||||
@@ -615,38 +641,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
end(filter_spatial_lengths_),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
|
||||
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
|
||||
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n_,
|
||||
GridwiseGemm64::CalculateMBlock(GemmM),
|
||||
GridwiseGemm64::CalculateNBlock(GemmN));
|
||||
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
const bool is_k_not_paded =
|
||||
(Conv_N_ * output_spatial_acum) % (K0PerBlock * k_batch_) == 0;
|
||||
// Check that there is no K padding and that N * OutSpatialDims is divisible by k_batch
|
||||
split_k_offset_a_hack_ =
|
||||
(Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
// Check that there is no K padding and that N is divisible by k_batch
|
||||
split_k_offset_b_hack_ =
|
||||
Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
|
||||
split_k_stride_a_ =
|
||||
a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_;
|
||||
split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_;
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
@@ -685,16 +694,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
void ShowInfo(const Argument& arg)
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
@@ -703,10 +712,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
template <typename GridwiseGemm>
|
||||
float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
|
||||
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
|
||||
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
|
||||
const index_t GemmK =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
const ADataType* p_a_grid = arg.p_a_grid_;
|
||||
const BDataType* p_b_grid = arg.p_b_grid_;
|
||||
@@ -724,7 +733,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto num_k_per_block =
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
if(arg.k_batch_ > 1)
|
||||
@@ -760,8 +769,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_arg_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
num_k_per_block,
|
||||
@@ -780,8 +789,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_arg,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
num_k_per_block,
|
||||
@@ -1341,10 +1350,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
}
|
||||
#endif
|
||||
|
||||
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
|
||||
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
|
||||
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
|
||||
const index_t GemmK =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
|
||||
{
|
||||
@@ -1475,8 +1484,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
}
|
||||
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
if(!(arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
|
||||
if(!(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
|
||||
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
|
||||
@@ -164,7 +164,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const CBlockClusterAdaptor c_block_cluster_adaptor)
|
||||
const CBlockClusterAdaptor c_block_cluster_adaptor,
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_a_hack,
|
||||
bool split_k_offset_b_hack,
|
||||
index_t k_batch)
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__)
|
||||
@@ -182,7 +187,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
c_block_cluster_adaptor);
|
||||
c_block_cluster_adaptor,
|
||||
split_k_stride_a,
|
||||
split_k_stride_b,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack,
|
||||
k_batch);
|
||||
}
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
@@ -195,6 +205,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = c_block_cluster_adaptor;
|
||||
ignore = split_k_stride_a;
|
||||
ignore = split_k_stride_b;
|
||||
ignore = split_k_offset_a_hack;
|
||||
ignore = split_k_offset_b_hack;
|
||||
ignore = k_batch;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
@@ -662,7 +677,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_a_hack,
|
||||
bool split_k_offset_b_hack)
|
||||
bool split_k_offset_b_hack,
|
||||
index_t k_batch)
|
||||
{
|
||||
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
|
||||
|
||||
@@ -677,10 +693,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
const long_index_t split_k_offset_b =
|
||||
split_k_offset_b_hack ? k_batch_id * split_k_stride_b : 0;
|
||||
|
||||
const long_index_t a_space_size_divisor =
|
||||
split_k_offset_a_hack ? a_b_k0_m_k1_grid_desc.GetLength(I0) : 1;
|
||||
const long_index_t b_space_size_divisor =
|
||||
split_k_offset_b_hack ? a_b_k0_m_k1_grid_desc.GetLength(I0) : 1;
|
||||
const long_index_t a_space_size_divisor = split_k_offset_a_hack ? k_batch : 1;
|
||||
const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1;
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid + split_k_offset_a,
|
||||
|
||||
@@ -150,7 +150,9 @@ struct TransformConvBwdWeightToGemm
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const index_t batch_k)
|
||||
const index_t batch_k,
|
||||
const bool split_k_offset_a_hack = false,
|
||||
const bool split_k_offset_b_hack = false)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -173,7 +175,9 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
|
||||
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
|
||||
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
@@ -191,7 +195,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -209,7 +213,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -247,7 +251,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -286,7 +290,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -324,7 +328,9 @@ struct TransformConvBwdWeightToGemm
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const index_t batch_k)
|
||||
const index_t batch_k,
|
||||
const bool split_k_offset_a_hack = false,
|
||||
const bool split_k_offset_b_hack = false)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -360,7 +366,9 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
|
||||
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
|
||||
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
|
||||
@@ -379,7 +387,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -394,7 +402,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -423,7 +431,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -464,7 +472,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -498,7 +506,9 @@ struct TransformConvBwdWeightToGemm
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const index_t batch_k)
|
||||
const index_t batch_k,
|
||||
const bool split_k_offset_a_hack = false,
|
||||
const bool split_k_offset_b_hack = false)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -541,7 +551,9 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
|
||||
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
|
||||
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
|
||||
@@ -560,7 +572,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -575,7 +587,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -604,7 +616,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
@@ -654,7 +666,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -324,7 +324,9 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const index_t batch_k)
|
||||
const index_t batch_k,
|
||||
const bool split_k_offset_a_hack = false,
|
||||
const bool split_k_offset_b_hack = false)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -353,7 +355,9 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
|
||||
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Wo, K, output_strides);
|
||||
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Wi, C, input_strides);
|
||||
@@ -373,7 +377,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -389,7 +393,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -419,7 +423,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -460,7 +464,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -495,7 +499,9 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const index_t batch_k)
|
||||
const index_t batch_k,
|
||||
const bool split_k_offset_a_hack = false,
|
||||
const bool split_k_offset_b_hack = false)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -531,7 +537,9 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
|
||||
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
|
||||
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
|
||||
@@ -551,7 +559,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -567,7 +575,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -597,7 +605,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -647,7 +655,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -681,7 +689,9 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const index_t batch_k)
|
||||
const index_t batch_k,
|
||||
const bool split_k_offset_a_hack = false,
|
||||
const bool split_k_offset_b_hack = false)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -724,7 +734,9 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
|
||||
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
|
||||
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
|
||||
@@ -744,7 +756,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -760,7 +772,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -790,7 +802,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -855,7 +867,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
Reference in New Issue
Block a user