mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
[CK][CK Tile] Improve access for merged groups and remove modulo from xor (#5334)
## Motivation

[CK][CK Tile] Improve access for merged groups and remove modulo from xor.

## Technical Details

- Add a boolean template parameter (`ApplyModulo`, default `true`) to the xor transform (`xor_t` / `make_xor_transform`) that selects whether the modulo is applied. The modulo is not needed for merged groups.
- Use access by m for merged groups for a tensor.

## Test Plan

test_grouped_convnd_fwd_tile

## Test Result

Passed locally.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
@@ -1298,7 +1298,7 @@ CK_TILE_HOST_DEVICE static void print(const modulo<Modulus, UpLength>& m)
|
||||
}
|
||||
|
||||
// 2D XOR, NOTE: "xor" is a keyword
|
||||
template <typename LowLengths>
|
||||
template <typename LowLengths, bool ApplyModulo = true>
|
||||
struct xor_t : public base_transform<2, 2>
|
||||
{
|
||||
static constexpr auto type_enum = coord_transform_enum::xor_t;
|
||||
@@ -1330,8 +1330,15 @@ struct xor_t : public base_transform<2, 2>
|
||||
|
||||
idx_low(number<0>{}) = idx_up[number<0>{}];
|
||||
|
||||
idx_low(number<1>{}) =
|
||||
idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
|
||||
if constexpr(ApplyModulo)
|
||||
{
|
||||
idx_low(number<1>{}) =
|
||||
idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
|
||||
}
|
||||
else
|
||||
{
|
||||
idx_low(number<1>{}) = idx_up[number<1>{}] ^ (idx_up[number<0>{}]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
|
||||
@@ -1382,8 +1389,8 @@ struct xor_t : public base_transform<2, 2>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLengths>
|
||||
CK_TILE_HOST_DEVICE static void print(const xor_t<LowLengths>& x)
|
||||
template <typename LowLengths, bool ApplyModulo = true>
|
||||
CK_TILE_HOST_DEVICE static void print(const xor_t<LowLengths, ApplyModulo>& x)
|
||||
{
|
||||
printf("xor_t{");
|
||||
printf("up_lengths_: ");
|
||||
@@ -1737,10 +1744,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
|
||||
return modulo<Modulus, UpLength>{modulus, up_length};
|
||||
}
|
||||
|
||||
template <typename LowLengths>
|
||||
template <typename LowLengths, bool ApplyModulo = true>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
|
||||
{
|
||||
return xor_t<LowLengths>{low_lengths};
|
||||
return xor_t<LowLengths, ApplyModulo>{low_lengths};
|
||||
}
|
||||
|
||||
template <typename LowLength, typename OffsetLength>
|
||||
|
||||
Reference in New Issue
Block a user