mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
add g; fixed strides (#355)
[ROCm/composable_kernel commit: 35e49f2de6]
This commit is contained in:
@@ -500,11 +500,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
|
||||
std::array<long_index_t, NumDTensor> ds_offset;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
if constexpr(NumDimG > 0)
|
||||
ds_offset[i] =
|
||||
ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(g_idx, 0, 0));
|
||||
else
|
||||
ds_offset[i] = 0;
|
||||
ds_offset[i] =
|
||||
ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(g_idx, 0, 0));
|
||||
});
|
||||
|
||||
return ds_offset;
|
||||
@@ -512,10 +509,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
|
||||
{
|
||||
if constexpr(NumDimG > 0)
|
||||
return e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
|
||||
else
|
||||
return 0;
|
||||
return e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -634,6 +628,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
|
||||
compute_ptr_offset_of_batch_{
|
||||
a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_}
|
||||
{
|
||||
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0, "");
|
||||
|
||||
// populate pointer, batch stride, desc for Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
Reference in New Issue
Block a user