add g; fixed strides (#355)

[ROCm/composable_kernel commit: 35e49f2de6]
This commit is contained in:
zjing14
2022-08-12 15:22:39 -05:00
committed by GitHub
parent 2bcb8c6fbc
commit 56de9aaa27
4 changed files with 275 additions and 246 deletions

View File

@@ -500,11 +500,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
std::array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}([&](auto i) {
if constexpr(NumDimG > 0)
ds_offset[i] =
ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(g_idx, 0, 0));
else
ds_offset[i] = 0;
ds_offset[i] =
ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(g_idx, 0, 0));
});
return ds_offset;
@@ -512,10 +509,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
if constexpr(NumDimG > 0)
return e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
else
return 0;
return e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
private:
@@ -634,6 +628,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
compute_ptr_offset_of_batch_{
a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_}
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0, "");
// populate pointer, batch stride, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;