mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Fixed G offset calculation for long_index (#428)
This commit is contained in:
@@ -506,12 +506,12 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
|
||||
|
||||
// Returns the offset into tensor A for batch index `g_idx`, computed as
// g_idx * batch_stride_A_.
//
// g_idx is converted to long_index_t before the multiplication so that the
// product is evaluated in long_index_t rather than in index_t.
// NOTE(review): presumably long_index_t is the wider (64-bit) index type and
// batch_stride_A_ is expressed in elements - confirm against the declarations.
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
    return static_cast<long_index_t>(g_idx) * batch_stride_A_;
}
|
||||
|
||||
// Returns the offset into tensor B for batch index `g_idx`, computed as
// g_idx * batch_stride_B_.
//
// g_idx is converted to long_index_t before the multiplication so that the
// product is evaluated in long_index_t rather than in index_t.
// NOTE(review): presumably long_index_t is the wider (64-bit) index type and
// batch_stride_B_ is expressed in elements - confirm against the declarations.
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
    return static_cast<long_index_t>(g_idx) * batch_stride_B_;
}
|
||||
|
||||
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
|
||||
@@ -519,8 +519,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
|
||||
std::array<long_index_t, NumDTensor> ds_offset;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
ds_offset[i] =
|
||||
ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(g_idx, 0, 0));
|
||||
ds_offset[i] = static_cast<long_index_t>(g_idx) *
|
||||
ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0));
|
||||
});
|
||||
|
||||
return ds_offset;
|
||||
@@ -528,7 +528,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
|
||||
|
||||
// Returns the offset into tensor E for batch index `g_idx`.
//
// The descriptor is queried only for the offset of index (1, 0, 0); that
// value is then scaled by g_idx after g_idx has been converted to
// long_index_t, so the multiplication is evaluated in long_index_t.
// NOTE(review): presumably CalculateOffset yields a narrower index_t, which
// is why (g_idx, 0, 0) is not passed to it directly - confirm.
// NOTE(review): this relies on the descriptor offset being linear in the
// first index with offset (0, 0, 0) == 0 - confirm for e_grid_desc_g_m_n_.
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
    return static_cast<long_index_t>(g_idx) *
           e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0));
}
|
||||
|
||||
private:
|
||||
|
||||
Reference in New Issue
Block a user