fixed G offset calc for long_index (#428)

This commit is contained in:
zjing14
2022-09-21 10:15:43 -05:00
committed by GitHub
parent 567f70f552
commit 01876afafe

View File

@@ -506,12 +506,12 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(batch_stride_A_);
return static_cast<long_index_t>(g_idx) * batch_stride_A_;
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(batch_stride_B_);
return static_cast<long_index_t>(g_idx) * batch_stride_B_;
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
@@ -519,8 +519,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
std::array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_offset[i] =
ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(g_idx, 0, 0));
ds_offset[i] = static_cast<long_index_t>(g_idx) *
ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0));
});
return ds_offset;
@@ -528,7 +528,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
return e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
return static_cast<long_index_t>(g_idx) *
e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0));
}
private: